refactor: update module docstrings for clarity and consistency
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""OpenClaw Mission Control backend application package."""
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""API router modules for the OpenClaw Mission Control backend."""
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Agent-scoped API routes for board operations and gateway coordination."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel import SQLModel, col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api import agents as agents_api
|
||||
from app.api import approvals as approvals_api
|
||||
@@ -27,11 +28,7 @@ from app.integrations.openclaw_gateway import (
|
||||
openclaw_call,
|
||||
send_message,
|
||||
)
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_memory import BoardMemory
|
||||
from app.models.board_onboarding import BoardOnboardingSession
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.task_dependencies import TaskDependency
|
||||
@@ -58,7 +55,13 @@ from app.schemas.gateway_coordination import (
|
||||
GatewayMainAskUserResponse,
|
||||
)
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
|
||||
from app.schemas.tasks import (
|
||||
TaskCommentCreate,
|
||||
TaskCommentRead,
|
||||
TaskCreate,
|
||||
TaskRead,
|
||||
TaskUpdate,
|
||||
)
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.board_leads import ensure_board_lead_agent
|
||||
from app.services.task_dependencies import (
|
||||
@@ -67,11 +70,27 @@ from app.services.task_dependencies import (
|
||||
validate_dependency_update,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_memory import BoardMemory
|
||||
from app.models.board_onboarding import BoardOnboardingSession
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
|
||||
_AGENT_SESSION_PREFIX = "agent:"
|
||||
_SESSION_KEY_PARTS_MIN = 2
|
||||
_LEAD_SESSION_KEY_MISSING = "Lead agent has no session key"
|
||||
SESSION_DEP = Depends(get_session)
|
||||
AGENT_CTX_DEP = Depends(get_agent_auth_context)
|
||||
BOARD_DEP = Depends(get_board_or_404)
|
||||
TASK_DEP = Depends(get_task_or_404)
|
||||
BOARD_ID_QUERY = Query(default=None)
|
||||
TASK_STATUS_QUERY = Query(default=None, alias="status")
|
||||
IS_CHAT_QUERY = Query(default=None)
|
||||
APPROVAL_STATUS_QUERY = Query(default=None, alias="status")
|
||||
|
||||
|
||||
def _gateway_agent_id(agent: Agent) -> str:
|
||||
@@ -87,6 +106,8 @@ def _gateway_agent_id(agent: Agent) -> str:
|
||||
|
||||
|
||||
class SoulUpdateRequest(SQLModel):
|
||||
"""Payload for updating an agent SOUL document."""
|
||||
|
||||
content: str
|
||||
source_url: str | None = None
|
||||
reason: str | None = None
|
||||
@@ -124,9 +145,12 @@ async def _require_gateway_main(
|
||||
session_key = (agent.openclaw_session_id or "").strip()
|
||||
if not session_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Agent missing session key"
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Agent missing session key",
|
||||
)
|
||||
gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(session)
|
||||
gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(
|
||||
session,
|
||||
)
|
||||
if gateway is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -148,7 +172,9 @@ async def _require_gateway_board(
|
||||
) -> Board:
|
||||
board = await Board.objects.by_id(board_id).first(session)
|
||||
if board is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
|
||||
)
|
||||
if board.gateway_id != gateway.id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
return board
|
||||
@@ -156,9 +182,10 @@ async def _require_gateway_board(
|
||||
|
||||
@router.get("/boards", response_model=DefaultLimitOffsetPage[BoardRead])
|
||||
async def list_boards(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> DefaultLimitOffsetPage[BoardRead]:
|
||||
"""List boards visible to the authenticated agent."""
|
||||
statement = select(Board)
|
||||
if agent_ctx.agent.board_id:
|
||||
statement = statement.where(col(Board.id) == agent_ctx.agent.board_id)
|
||||
@@ -168,19 +195,21 @@ async def list_boards(
|
||||
|
||||
@router.get("/boards/{board_id}", response_model=BoardRead)
|
||||
def get_board(
|
||||
board: Board = Depends(get_board_or_404),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> Board:
|
||||
"""Return a board if the authenticated agent can access it."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
return board
|
||||
|
||||
|
||||
@router.get("/agents", response_model=DefaultLimitOffsetPage[AgentRead])
|
||||
async def list_agents(
|
||||
board_id: UUID | None = Query(default=None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board_id: UUID | None = BOARD_ID_QUERY,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> DefaultLimitOffsetPage[AgentRead]:
|
||||
"""List agents, optionally filtered to a board."""
|
||||
statement = select(Agent)
|
||||
if agent_ctx.agent.board_id:
|
||||
if board_id and board_id != agent_ctx.agent.board_id:
|
||||
@@ -188,13 +217,19 @@ async def list_agents(
|
||||
statement = statement.where(Agent.board_id == agent_ctx.agent.board_id)
|
||||
elif board_id:
|
||||
statement = statement.where(Agent.board_id == board_id)
|
||||
main_session_keys = await agents_api._get_gateway_main_session_keys(session)
|
||||
get_gateway_main_session_keys = (
|
||||
agents_api._get_gateway_main_session_keys # noqa: SLF001
|
||||
)
|
||||
to_agent_read = agents_api._to_agent_read # noqa: SLF001
|
||||
with_computed_status = agents_api._with_computed_status # noqa: SLF001
|
||||
|
||||
main_session_keys = await get_gateway_main_session_keys(session)
|
||||
statement = statement.order_by(col(Agent.created_at).desc())
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
agents = cast(Sequence[Agent], items)
|
||||
return [
|
||||
agents_api._to_agent_read(agents_api._with_computed_status(agent), main_session_keys)
|
||||
to_agent_read(with_computed_status(agent), main_session_keys)
|
||||
for agent in agents
|
||||
]
|
||||
|
||||
@@ -202,14 +237,15 @@ async def list_agents(
|
||||
|
||||
|
||||
@router.get("/boards/{board_id}/tasks", response_model=DefaultLimitOffsetPage[TaskRead])
|
||||
async def list_tasks(
|
||||
status_filter: str | None = Query(default=None, alias="status"),
|
||||
async def list_tasks( # noqa: PLR0913
|
||||
status_filter: str | None = TASK_STATUS_QUERY,
|
||||
assigned_agent_id: UUID | None = None,
|
||||
unassigned: bool | None = None,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> DefaultLimitOffsetPage[TaskRead]:
|
||||
"""List tasks on a board with optional status and assignment filters."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
return await tasks_api.list_tasks(
|
||||
status_filter=status_filter,
|
||||
@@ -224,10 +260,11 @@ async def list_tasks(
|
||||
@router.post("/boards/{board_id}/tasks", response_model=TaskRead)
|
||||
async def create_task(
|
||||
payload: TaskCreate,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> TaskRead:
|
||||
"""Create a task on the board as the lead agent."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
if not agent_ctx.agent.is_board_lead:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
@@ -250,7 +287,9 @@ async def create_task(
|
||||
board_id=board.id,
|
||||
dependency_ids=normalized_deps,
|
||||
)
|
||||
blocked_by = blocked_by_dependency_ids(dependency_ids=normalized_deps, status_by_id=dep_status)
|
||||
blocked_by = blocked_by_dependency_ids(
|
||||
dependency_ids=normalized_deps, status_by_id=dep_status,
|
||||
)
|
||||
|
||||
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
|
||||
raise HTTPException(
|
||||
@@ -280,7 +319,7 @@ async def create_task(
|
||||
board_id=board.id,
|
||||
task_id=task.id,
|
||||
depends_on_task_id=dep_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(task)
|
||||
@@ -293,9 +332,14 @@ async def create_task(
|
||||
)
|
||||
await session.commit()
|
||||
if task.assigned_agent_id:
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
|
||||
session,
|
||||
)
|
||||
if assigned_agent:
|
||||
await tasks_api._notify_agent_on_task_assign(
|
||||
notify_agent_on_task_assign = (
|
||||
tasks_api._notify_agent_on_task_assign # noqa: SLF001
|
||||
)
|
||||
await notify_agent_on_task_assign(
|
||||
session=session,
|
||||
board=board,
|
||||
task=task,
|
||||
@@ -306,18 +350,23 @@ async def create_task(
|
||||
"depends_on_task_ids": normalized_deps,
|
||||
"blocked_by_task_ids": blocked_by,
|
||||
"is_blocked": bool(blocked_by),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/boards/{board_id}/tasks/{task_id}", response_model=TaskRead)
|
||||
async def update_task(
|
||||
payload: TaskUpdate,
|
||||
task: Task = Depends(get_task_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
task: Task = TASK_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> TaskRead:
|
||||
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
|
||||
"""Update a task after board-level access checks."""
|
||||
if (
|
||||
agent_ctx.agent.board_id
|
||||
and task.board_id
|
||||
and agent_ctx.agent.board_id != task.board_id
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
return await tasks_api.update_task(
|
||||
payload=payload,
|
||||
@@ -332,11 +381,16 @@ async def update_task(
|
||||
response_model=DefaultLimitOffsetPage[TaskCommentRead],
|
||||
)
|
||||
async def list_task_comments(
|
||||
task: Task = Depends(get_task_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
task: Task = TASK_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> DefaultLimitOffsetPage[TaskCommentRead]:
|
||||
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
|
||||
"""List comments for a task visible to the authenticated agent."""
|
||||
if (
|
||||
agent_ctx.agent.board_id
|
||||
and task.board_id
|
||||
and agent_ctx.agent.board_id != task.board_id
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
return await tasks_api.list_task_comments(
|
||||
task=task,
|
||||
@@ -344,14 +398,21 @@ async def list_task_comments(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead)
|
||||
@router.post(
|
||||
"/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead,
|
||||
)
|
||||
async def create_task_comment(
|
||||
payload: TaskCommentCreate,
|
||||
task: Task = Depends(get_task_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
task: Task = TASK_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> ActivityEvent:
|
||||
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
|
||||
"""Create a task comment on behalf of the authenticated agent."""
|
||||
if (
|
||||
agent_ctx.agent.board_id
|
||||
and task.board_id
|
||||
and agent_ctx.agent.board_id != task.board_id
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
return await tasks_api.create_task_comment(
|
||||
payload=payload,
|
||||
@@ -361,13 +422,16 @@ async def create_task_comment(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead])
|
||||
@router.get(
|
||||
"/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead],
|
||||
)
|
||||
async def list_board_memory(
|
||||
is_chat: bool | None = Query(default=None),
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
is_chat: bool | None = IS_CHAT_QUERY,
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> DefaultLimitOffsetPage[BoardMemoryRead]:
|
||||
"""List board memory entries with optional chat filtering."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
return await board_memory_api.list_board_memory(
|
||||
is_chat=is_chat,
|
||||
@@ -380,10 +444,11 @@ async def list_board_memory(
|
||||
@router.post("/boards/{board_id}/memory", response_model=BoardMemoryRead)
|
||||
async def create_board_memory(
|
||||
payload: BoardMemoryCreate,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> BoardMemory:
|
||||
"""Create a board memory entry."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
return await board_memory_api.create_board_memory(
|
||||
payload=payload,
|
||||
@@ -398,11 +463,12 @@ async def create_board_memory(
|
||||
response_model=DefaultLimitOffsetPage[ApprovalRead],
|
||||
)
|
||||
async def list_approvals(
|
||||
status_filter: ApprovalStatus | None = Query(default=None, alias="status"),
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
status_filter: ApprovalStatus | None = APPROVAL_STATUS_QUERY,
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> DefaultLimitOffsetPage[ApprovalRead]:
|
||||
"""List approvals for a board."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
return await approvals_api.list_approvals(
|
||||
status_filter=status_filter,
|
||||
@@ -415,10 +481,11 @@ async def list_approvals(
|
||||
@router.post("/boards/{board_id}/approvals", response_model=ApprovalRead)
|
||||
async def create_approval(
|
||||
payload: ApprovalCreate,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> Approval:
|
||||
"""Create a board approval request."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
return await approvals_api.create_approval(
|
||||
payload=payload,
|
||||
@@ -431,10 +498,11 @@ async def create_approval(
|
||||
@router.post("/boards/{board_id}/onboarding", response_model=BoardOnboardingRead)
|
||||
async def update_onboarding(
|
||||
payload: BoardOnboardingAgentUpdate,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> BoardOnboardingSession:
|
||||
"""Apply onboarding updates for a board."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
return await onboarding_api.agent_onboarding_update(
|
||||
payload=payload,
|
||||
@@ -447,14 +515,17 @@ async def update_onboarding(
|
||||
@router.post("/agents", response_model=AgentRead)
|
||||
async def create_agent(
|
||||
payload: AgentCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> AgentRead:
|
||||
"""Create an agent on the caller's board."""
|
||||
if not agent_ctx.agent.is_board_lead:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if not agent_ctx.agent.board_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
payload = AgentCreate(**{**payload.model_dump(), "board_id": agent_ctx.agent.board_id})
|
||||
payload = AgentCreate(
|
||||
**{**payload.model_dump(), "board_id": agent_ctx.agent.board_id},
|
||||
)
|
||||
return await agents_api.create_agent(
|
||||
payload=payload,
|
||||
session=session,
|
||||
@@ -466,10 +537,11 @@ async def create_agent(
|
||||
async def nudge_agent(
|
||||
payload: AgentNudge,
|
||||
agent_id: str,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> OkResponse:
|
||||
"""Send a direct nudge message to a board agent."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
if not agent_ctx.agent.is_board_lead:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
@@ -484,7 +556,9 @@ async def nudge_agent(
|
||||
message = payload.message
|
||||
config = await _gateway_config(session, board)
|
||||
try:
|
||||
await ensure_session(target.openclaw_session_id, config=config, label=target.name)
|
||||
await ensure_session(
|
||||
target.openclaw_session_id, config=config, label=target.name,
|
||||
)
|
||||
await send_message(
|
||||
message,
|
||||
session_key=target.openclaw_session_id,
|
||||
@@ -499,7 +573,9 @@ async def nudge_agent(
|
||||
agent_id=agent_ctx.agent.id,
|
||||
)
|
||||
await session.commit()
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
record_activity(
|
||||
session,
|
||||
event_type="agent.nudge.sent",
|
||||
@@ -513,9 +589,10 @@ async def nudge_agent(
|
||||
@router.post("/heartbeat", response_model=AgentRead)
|
||||
async def agent_heartbeat(
|
||||
payload: AgentHeartbeatCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> AgentRead:
|
||||
"""Record heartbeat status for the authenticated agent."""
|
||||
# Heartbeats must apply to the authenticated agent; agent names are not unique.
|
||||
return await agents_api.heartbeat_agent(
|
||||
agent_id=str(agent_ctx.agent.id),
|
||||
@@ -528,10 +605,11 @@ async def agent_heartbeat(
|
||||
@router.get("/boards/{board_id}/agents/{agent_id}/soul", response_model=str)
|
||||
async def get_agent_soul(
|
||||
agent_id: str,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> str:
|
||||
"""Fetch the target agent's SOUL.md content from the gateway."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
@@ -547,7 +625,9 @@ async def get_agent_soul(
|
||||
config=config,
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
if isinstance(payload, str):
|
||||
return payload
|
||||
if isinstance(payload, dict):
|
||||
@@ -559,17 +639,20 @@ async def get_agent_soul(
|
||||
nested = file_obj.get("content")
|
||||
if isinstance(nested, str):
|
||||
return nested
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response",
|
||||
)
|
||||
|
||||
|
||||
@router.put("/boards/{board_id}/agents/{agent_id}/soul", response_model=OkResponse)
|
||||
async def update_agent_soul(
|
||||
agent_id: str,
|
||||
payload: SoulUpdateRequest,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> OkResponse:
|
||||
"""Update an agent's SOUL.md content in DB and gateway."""
|
||||
_guard_board_access(agent_ctx, board)
|
||||
if not agent_ctx.agent.is_board_lead:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
@@ -597,7 +680,9 @@ async def update_agent_soul(
|
||||
config=config,
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
reason = (payload.reason or "").strip()
|
||||
source_url = (payload.source_url or "").strip()
|
||||
note = f"SOUL.md updated for {target.name}."
|
||||
@@ -621,10 +706,11 @@ async def update_agent_soul(
|
||||
)
|
||||
async def ask_user_via_gateway_main(
|
||||
payload: GatewayMainAskUserRequest,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
board: Board = BOARD_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> GatewayMainAskUserResponse:
|
||||
"""Route a lead's ask-user request through the gateway main agent."""
|
||||
import json
|
||||
|
||||
_guard_board_access(agent_ctx, board)
|
||||
@@ -653,7 +739,9 @@ async def ask_user_via_gateway_main(
|
||||
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
||||
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
||||
preferred_channel = (payload.preferred_channel or "").strip()
|
||||
channel_line = f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
|
||||
channel_line = (
|
||||
f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
|
||||
)
|
||||
|
||||
tags = payload.reply_tags or ["gateway_main", "user_reply"]
|
||||
tags_json = json.dumps(tags)
|
||||
@@ -668,9 +756,12 @@ async def ask_user_via_gateway_main(
|
||||
f"{correlation_line}"
|
||||
f"{channel_line}\n"
|
||||
f"{payload.content.strip()}\n\n"
|
||||
"Please reach the user via your configured OpenClaw channel(s) (Slack/SMS/etc).\n"
|
||||
"If you cannot reach them there, post the question in Mission Control board chat as a fallback.\n\n"
|
||||
"When you receive the answer, reply in Mission Control by writing a NON-chat memory item on this board:\n"
|
||||
"Please reach the user via your configured OpenClaw channel(s) "
|
||||
"(Slack/SMS/etc).\n"
|
||||
"If you cannot reach them there, post the question in Mission Control "
|
||||
"board chat as a fallback.\n\n"
|
||||
"When you receive the answer, reply in Mission Control by writing a "
|
||||
"NON-chat memory item on this board:\n"
|
||||
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
|
||||
f'Body: {{"content":"<answer>","tags":{tags_json},"source":"{reply_source}"}}\n'
|
||||
"Do NOT reply in OpenClaw chat."
|
||||
@@ -678,7 +769,9 @@ async def ask_user_via_gateway_main(
|
||||
|
||||
try:
|
||||
await ensure_session(main_session_key, config=config, label="Main Agent")
|
||||
await send_message(message, session_key=main_session_key, config=config, deliver=True)
|
||||
await send_message(
|
||||
message, session_key=main_session_key, config=config, deliver=True,
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
record_activity(
|
||||
session,
|
||||
@@ -687,7 +780,9 @@ async def ask_user_via_gateway_main(
|
||||
agent_id=agent_ctx.agent.id,
|
||||
)
|
||||
await session.commit()
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
|
||||
record_activity(
|
||||
session,
|
||||
@@ -696,7 +791,9 @@ async def ask_user_via_gateway_main(
|
||||
agent_id=agent_ctx.agent.id,
|
||||
)
|
||||
|
||||
main_agent = await Agent.objects.filter_by(openclaw_session_id=main_session_key).first(session)
|
||||
main_agent = await Agent.objects.filter_by(
|
||||
openclaw_session_id=main_session_key,
|
||||
).first(session)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@@ -714,9 +811,10 @@ async def ask_user_via_gateway_main(
|
||||
async def message_gateway_board_lead(
|
||||
board_id: UUID,
|
||||
payload: GatewayLeadMessageRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> GatewayLeadMessageResponse:
|
||||
"""Send a gateway-main message to a single board lead agent."""
|
||||
import json
|
||||
|
||||
gateway, config = await _require_gateway_main(session, agent_ctx.agent)
|
||||
@@ -736,7 +834,11 @@ async def message_gateway_board_lead(
|
||||
)
|
||||
|
||||
base_url = settings.base_url or "http://localhost:8000"
|
||||
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
|
||||
header = (
|
||||
"GATEWAY MAIN QUESTION"
|
||||
if payload.kind == "question"
|
||||
else "GATEWAY MAIN HANDOFF"
|
||||
)
|
||||
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
||||
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
||||
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
|
||||
@@ -767,7 +869,9 @@ async def message_gateway_board_lead(
|
||||
agent_id=agent_ctx.agent.id,
|
||||
)
|
||||
await session.commit()
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
|
||||
record_activity(
|
||||
session,
|
||||
@@ -791,9 +895,10 @@ async def message_gateway_board_lead(
|
||||
)
|
||||
async def broadcast_gateway_lead_message(
|
||||
payload: GatewayLeadBroadcastRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||
) -> GatewayLeadBroadcastResponse:
|
||||
"""Broadcast a gateway-main message to multiple board leads."""
|
||||
import json
|
||||
|
||||
gateway, config = await _require_gateway_main(session, agent_ctx.agent)
|
||||
@@ -808,7 +913,11 @@ async def broadcast_gateway_lead_message(
|
||||
boards = list(await session.exec(statement))
|
||||
|
||||
base_url = settings.base_url or "http://localhost:8000"
|
||||
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
|
||||
header = (
|
||||
"GATEWAY MAIN QUESTION"
|
||||
if payload.kind == "question"
|
||||
else "GATEWAY MAIN HANDOFF"
|
||||
)
|
||||
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
||||
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
||||
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
|
||||
@@ -819,7 +928,7 @@ async def broadcast_gateway_lead_message(
|
||||
sent = 0
|
||||
failed = 0
|
||||
|
||||
for board in boards:
|
||||
async def _send_to_board(board: Board) -> GatewayLeadBroadcastBoardResult:
|
||||
try:
|
||||
lead, _lead_created = await ensure_board_lead_agent(
|
||||
session,
|
||||
@@ -837,30 +946,34 @@ async def broadcast_gateway_lead_message(
|
||||
f"From agent: {agent_ctx.agent.name}\n"
|
||||
f"{correlation_line}\n"
|
||||
f"{payload.content.strip()}\n\n"
|
||||
"Reply to the gateway main by writing a NON-chat memory item on this board:\n"
|
||||
"Reply to the gateway main by writing a NON-chat memory item "
|
||||
"on this board:\n"
|
||||
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
|
||||
f'Body: {{"content":"...","tags":{tags_json},"source":"{reply_source}"}}\n'
|
||||
f'Body: {{"content":"...","tags":{tags_json},'
|
||||
f'"source":"{reply_source}"}}\n'
|
||||
"Do NOT reply in OpenClaw chat."
|
||||
)
|
||||
await ensure_session(lead_session_key, config=config, label=lead.name)
|
||||
await send_message(message, session_key=lead_session_key, config=config)
|
||||
results.append(
|
||||
GatewayLeadBroadcastBoardResult(
|
||||
board_id=board.id,
|
||||
lead_agent_id=lead.id,
|
||||
lead_agent_name=lead.name,
|
||||
ok=True,
|
||||
)
|
||||
return GatewayLeadBroadcastBoardResult(
|
||||
board_id=board.id,
|
||||
lead_agent_id=lead.id,
|
||||
lead_agent_name=lead.name,
|
||||
ok=True,
|
||||
)
|
||||
sent += 1
|
||||
except (HTTPException, OpenClawGatewayError, ValueError) as exc:
|
||||
results.append(
|
||||
GatewayLeadBroadcastBoardResult(
|
||||
board_id=board.id,
|
||||
ok=False,
|
||||
error=str(exc),
|
||||
)
|
||||
return GatewayLeadBroadcastBoardResult(
|
||||
board_id=board.id,
|
||||
ok=False,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
for board in boards:
|
||||
board_result = await _send_to_board(board)
|
||||
results.append(board_result)
|
||||
if board_result.ok:
|
||||
sent += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
record_activity(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Agent lifecycle, listing, heartbeat, and deletion API endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -5,14 +7,12 @@ import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import asc, or_
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin
|
||||
@@ -23,14 +23,17 @@ from app.db import crud
|
||||
from app.db.pagination import paginate
|
||||
from app.db.session import async_session_maker, get_session
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
send_message,
|
||||
)
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.organizations import Organization
|
||||
from app.models.tasks import Task
|
||||
from app.models.users import User
|
||||
from app.schemas.agents import (
|
||||
AgentCreate,
|
||||
AgentHeartbeat,
|
||||
@@ -56,10 +59,23 @@ from app.services.organizations import (
|
||||
require_board_access,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.users import User
|
||||
|
||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
OFFLINE_AFTER = timedelta(minutes=10)
|
||||
AGENT_SESSION_PREFIX = "agent"
|
||||
BOARD_ID_QUERY = Query(default=None)
|
||||
GATEWAY_ID_QUERY = Query(default=None)
|
||||
SINCE_QUERY = Query(default=None)
|
||||
SESSION_DEP = Depends(get_session)
|
||||
ORG_ADMIN_DEP = Depends(require_org_admin)
|
||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||
AUTH_DEP = Depends(get_auth_context)
|
||||
|
||||
|
||||
def _parse_since(value: str | None) -> datetime | None:
|
||||
@@ -111,14 +127,16 @@ async def _require_board(
|
||||
)
|
||||
board = await Board.objects.by_id(board_id).first(session)
|
||||
if board is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
|
||||
)
|
||||
if user is not None:
|
||||
await require_board_access(session, user=user, board=board, write=write)
|
||||
return board
|
||||
|
||||
|
||||
async def _require_gateway(
|
||||
session: AsyncSession, board: Board
|
||||
session: AsyncSession, board: Board,
|
||||
) -> tuple[Gateway, GatewayClientConfig]:
|
||||
if not board.gateway_id:
|
||||
raise HTTPException(
|
||||
@@ -169,16 +187,20 @@ async def _get_gateway_main_session_keys(session: AsyncSession) -> set[str]:
|
||||
|
||||
|
||||
def _is_gateway_main(agent: Agent, main_session_keys: set[str]) -> bool:
|
||||
return bool(agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys)
|
||||
return bool(
|
||||
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys,
|
||||
)
|
||||
|
||||
|
||||
def _to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
|
||||
model = AgentRead.model_validate(agent, from_attributes=True)
|
||||
return model.model_copy(update={"is_gateway_main": _is_gateway_main(agent, main_session_keys)})
|
||||
return model.model_copy(
|
||||
update={"is_gateway_main": _is_gateway_main(agent, main_session_keys)},
|
||||
)
|
||||
|
||||
|
||||
async def _find_gateway_for_main_session(
|
||||
session: AsyncSession, session_key: str | None
|
||||
session: AsyncSession, session_key: str | None,
|
||||
) -> Gateway | None:
|
||||
if not session_key:
|
||||
return None
|
||||
@@ -210,7 +232,9 @@ def _with_computed_status(agent: Agent) -> Agent:
|
||||
|
||||
|
||||
def _serialize_agent(agent: Agent, main_session_keys: set[str]) -> dict[str, object]:
|
||||
return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump(mode="json")
|
||||
return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump(
|
||||
mode="json",
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_agent_events(
|
||||
@@ -225,18 +249,22 @@ async def _fetch_agent_events(
|
||||
or_(
|
||||
col(Agent.updated_at) >= since,
|
||||
col(Agent.last_seen_at) >= since,
|
||||
)
|
||||
),
|
||||
).order_by(asc(col(Agent.updated_at)))
|
||||
return list(await session.exec(statement))
|
||||
|
||||
|
||||
async def _require_user_context(session: AsyncSession, user: User | None) -> OrganizationContext:
|
||||
async def _require_user_context(
|
||||
session: AsyncSession, user: User | None,
|
||||
) -> OrganizationContext:
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
member = await get_active_membership(session, user)
|
||||
if member is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
organization = await Organization.objects.by_id(member.organization_id).first(session)
|
||||
organization = await Organization.objects.by_id(member.organization_id).first(
|
||||
session,
|
||||
)
|
||||
if organization is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
return OrganizationContext(organization=organization, member=member)
|
||||
@@ -252,7 +280,9 @@ async def _require_agent_access(
|
||||
if agent.board_id is None:
|
||||
if not is_org_admin(ctx.member):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
gateway = await _find_gateway_for_main_session(session, agent.openclaw_session_id)
|
||||
gateway = await _find_gateway_for_main_session(
|
||||
session, agent.openclaw_session_id,
|
||||
)
|
||||
if gateway is None or gateway.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
return
|
||||
@@ -274,7 +304,7 @@ def _record_heartbeat(session: AsyncSession, agent: Agent) -> None:
|
||||
|
||||
|
||||
def _record_instruction_failure(
|
||||
session: AsyncSession, agent: Agent, error: str, action: str
|
||||
session: AsyncSession, agent: Agent, error: str, action: str,
|
||||
) -> None:
|
||||
action_label = action.replace("_", " ").capitalize()
|
||||
record_activity(
|
||||
@@ -286,7 +316,7 @@ def _record_instruction_failure(
|
||||
|
||||
|
||||
async def _send_wakeup_message(
|
||||
agent: Agent, config: GatewayClientConfig, verb: str = "provisioned"
|
||||
agent: Agent, config: GatewayClientConfig, verb: str = "provisioned",
|
||||
) -> None:
|
||||
session_key = agent.openclaw_session_id or _build_session_key(agent.name)
|
||||
await ensure_session(session_key, config=config, label=agent.name)
|
||||
@@ -300,11 +330,12 @@ async def _send_wakeup_message(
|
||||
|
||||
@router.get("", response_model=DefaultLimitOffsetPage[AgentRead])
|
||||
async def list_agents(
|
||||
board_id: UUID | None = Query(default=None),
|
||||
gateway_id: UUID | None = Query(default=None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
board_id: UUID | None = BOARD_ID_QUERY,
|
||||
gateway_id: UUID | None = GATEWAY_ID_QUERY,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> DefaultLimitOffsetPage[AgentRead]:
|
||||
"""List agents visible to the active organization admin."""
|
||||
main_session_keys = await _get_gateway_main_session_keys(session)
|
||||
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
|
||||
if board_id is not None and board_id not in set(board_ids):
|
||||
@@ -315,9 +346,11 @@ async def list_agents(
|
||||
base_filter: ColumnElement[bool] = col(Agent.board_id).in_(board_ids)
|
||||
if is_org_admin(ctx.member):
|
||||
gateway_keys = select(Gateway.main_session_key).where(
|
||||
col(Gateway.organization_id) == ctx.organization.id
|
||||
col(Gateway.organization_id) == ctx.organization.id,
|
||||
)
|
||||
base_filter = or_(
|
||||
base_filter, col(Agent.openclaw_session_id).in_(gateway_keys),
|
||||
)
|
||||
base_filter = or_(base_filter, col(Agent.openclaw_session_id).in_(gateway_keys))
|
||||
statement = select(Agent).where(base_filter)
|
||||
if board_id is not None:
|
||||
statement = statement.where(col(Agent.board_id) == board_id)
|
||||
@@ -326,13 +359,16 @@ async def list_agents(
|
||||
if gateway is None or gateway.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where(
|
||||
col(Board.gateway_id) == gateway_id
|
||||
col(Board.gateway_id) == gateway_id,
|
||||
)
|
||||
statement = statement.order_by(col(Agent.created_at).desc())
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
agents = cast(Sequence[Agent], items)
|
||||
return [_to_agent_read(_with_computed_status(agent), main_session_keys) for agent in agents]
|
||||
return [
|
||||
_to_agent_read(_with_computed_status(agent), main_session_keys)
|
||||
for agent in agents
|
||||
]
|
||||
|
||||
return await paginate(session, statement, transformer=_transform)
|
||||
|
||||
@@ -340,11 +376,12 @@ async def list_agents(
|
||||
@router.get("/stream")
|
||||
async def stream_agents(
|
||||
request: Request,
|
||||
board_id: UUID | None = Query(default=None),
|
||||
since: str | None = Query(default=None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
board_id: UUID | None = BOARD_ID_QUERY,
|
||||
since: str | None = SINCE_QUERY,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> EventSourceResponse:
|
||||
"""Stream agent updates as SSE events."""
|
||||
since_dt = _parse_since(since) or utcnow()
|
||||
last_seen = since_dt
|
||||
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
|
||||
@@ -359,14 +396,20 @@ async def stream_agents(
|
||||
break
|
||||
async with async_session_maker() as stream_session:
|
||||
if board_id is not None:
|
||||
agents = await _fetch_agent_events(stream_session, board_id, last_seen)
|
||||
agents = await _fetch_agent_events(
|
||||
stream_session, board_id, last_seen,
|
||||
)
|
||||
elif allowed_ids:
|
||||
agents = await _fetch_agent_events(stream_session, None, last_seen)
|
||||
agents = [agent for agent in agents if agent.board_id in allowed_ids]
|
||||
agents = [
|
||||
agent for agent in agents if agent.board_id in allowed_ids
|
||||
]
|
||||
else:
|
||||
agents = []
|
||||
main_session_keys = (
|
||||
await _get_gateway_main_session_keys(stream_session) if agents else set()
|
||||
await _get_gateway_main_session_keys(stream_session)
|
||||
if agents
|
||||
else set()
|
||||
)
|
||||
for agent in agents:
|
||||
updated_at = agent.updated_at or agent.last_seen_at or utcnow()
|
||||
@@ -379,11 +422,12 @@ async def stream_agents(
|
||||
|
||||
|
||||
@router.post("", response_model=AgentRead)
|
||||
async def create_agent(
|
||||
async def create_agent( # noqa: C901, PLR0912, PLR0915
|
||||
payload: AgentCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
actor: ActorContext = ACTOR_DEP,
|
||||
) -> AgentRead:
|
||||
"""Create and provision an agent."""
|
||||
if actor.actor_type == "user":
|
||||
ctx = await _require_user_context(session, actor.user)
|
||||
if not is_org_admin(ctx.member):
|
||||
@@ -404,7 +448,9 @@ async def create_agent(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Board leads can only create agents in their own board",
|
||||
)
|
||||
payload = AgentCreate(**{**payload.model_dump(), "board_id": actor.agent.board_id})
|
||||
payload = AgentCreate(
|
||||
**{**payload.model_dump(), "board_id": actor.agent.board_id},
|
||||
)
|
||||
|
||||
board = await _require_board(
|
||||
session,
|
||||
@@ -420,7 +466,7 @@ async def create_agent(
|
||||
await session.exec(
|
||||
select(Agent)
|
||||
.where(Agent.board_id == board.id)
|
||||
.where(col(Agent.name).ilike(requested_name))
|
||||
.where(col(Agent.name).ilike(requested_name)),
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
@@ -428,20 +474,23 @@ async def create_agent(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="An agent with this name already exists on this board.",
|
||||
)
|
||||
# Prevent OpenClaw session/workspace collisions by enforcing uniqueness within
|
||||
# the gateway workspace too (agents on other boards share the same gateway root).
|
||||
# Prevent session/workspace collisions inside the gateway workspace.
|
||||
# Agents on different boards can still share one gateway root.
|
||||
existing_gateway = (
|
||||
await session.exec(
|
||||
select(Agent)
|
||||
.join(Board, col(Agent.board_id) == col(Board.id))
|
||||
.where(col(Board.gateway_id) == gateway.id)
|
||||
.where(col(Agent.name).ilike(requested_name))
|
||||
.where(col(Agent.name).ilike(requested_name)),
|
||||
)
|
||||
).first()
|
||||
if existing_gateway:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="An agent with this name already exists in this gateway workspace.",
|
||||
detail=(
|
||||
"An agent with this name already exists in this gateway "
|
||||
"workspace."
|
||||
),
|
||||
)
|
||||
desired_session_key = _build_session_key(requested_name)
|
||||
existing_session_key = (
|
||||
@@ -449,13 +498,16 @@ async def create_agent(
|
||||
select(Agent)
|
||||
.join(Board, col(Agent.board_id) == col(Board.id))
|
||||
.where(col(Board.gateway_id) == gateway.id)
|
||||
.where(col(Agent.openclaw_session_id) == desired_session_key)
|
||||
.where(col(Agent.openclaw_session_id) == desired_session_key),
|
||||
)
|
||||
).first()
|
||||
if existing_session_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="This agent name would collide with an existing workspace session key. Pick a different name.",
|
||||
detail=(
|
||||
"This agent name would collide with an existing workspace "
|
||||
"session key. Pick a different name."
|
||||
),
|
||||
)
|
||||
agent = Agent.model_validate(data)
|
||||
agent.status = "provisioning"
|
||||
@@ -465,7 +517,9 @@ async def create_agent(
|
||||
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
|
||||
agent.provision_requested_at = utcnow()
|
||||
agent.provision_action = "provision"
|
||||
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
|
||||
session_key, session_error = await _ensure_gateway_session(
|
||||
agent.name, client_config,
|
||||
)
|
||||
agent.openclaw_session_id = session_key
|
||||
session.add(agent)
|
||||
await session.commit()
|
||||
@@ -527,9 +581,10 @@ async def create_agent(
|
||||
@router.get("/{agent_id}", response_model=AgentRead)
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> AgentRead:
|
||||
"""Get a single agent by id."""
|
||||
agent = await Agent.objects.by_id(agent_id).first(session)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
@@ -539,14 +594,16 @@ async def get_agent(
|
||||
|
||||
|
||||
@router.patch("/{agent_id}", response_model=AgentRead)
|
||||
async def update_agent(
|
||||
async def update_agent( # noqa: C901, PLR0912, PLR0913, PLR0915
|
||||
agent_id: str,
|
||||
payload: AgentUpdate,
|
||||
*,
|
||||
force: bool = False,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> AgentRead:
|
||||
"""Update agent metadata and optionally reprovision."""
|
||||
agent = await Agent.objects.by_id(agent_id).first(session)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
@@ -564,12 +621,16 @@ async def update_agent(
|
||||
new_board = await _require_board(session, updates["board_id"])
|
||||
if new_board.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
if not await has_board_access(session, member=ctx.member, board=new_board, write=True):
|
||||
if not await has_board_access(
|
||||
session, member=ctx.member, board=new_board, write=True,
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if not updates and not force and make_main is None:
|
||||
main_session_keys = await _get_gateway_main_session_keys(session)
|
||||
return _to_agent_read(_with_computed_status(agent), main_session_keys)
|
||||
main_gateway = await _find_gateway_for_main_session(session, agent.openclaw_session_id)
|
||||
main_gateway = await _find_gateway_for_main_session(
|
||||
session, agent.openclaw_session_id,
|
||||
)
|
||||
gateway_for_main: Gateway | None = None
|
||||
if make_main is True:
|
||||
board_source = updates.get("board_id") or agent.board_id
|
||||
@@ -723,9 +784,10 @@ async def update_agent(
|
||||
async def heartbeat_agent(
|
||||
agent_id: str,
|
||||
payload: AgentHeartbeat,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
actor: ActorContext = ACTOR_DEP,
|
||||
) -> AgentRead:
|
||||
"""Record a heartbeat for a specific agent."""
|
||||
agent = await Agent.objects.by_id(agent_id).first(session)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
@@ -751,12 +813,14 @@ async def heartbeat_agent(
|
||||
|
||||
|
||||
@router.post("/heartbeat", response_model=AgentRead)
|
||||
async def heartbeat_or_create_agent(
|
||||
async def heartbeat_or_create_agent( # noqa: C901, PLR0912, PLR0915
|
||||
payload: AgentHeartbeatCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
actor: ActorContext = ACTOR_DEP,
|
||||
) -> AgentRead:
|
||||
# Agent tokens must heartbeat their authenticated agent record. Names are not unique.
|
||||
"""Heartbeat an existing agent or create/provision one if needed."""
|
||||
# Agent tokens must heartbeat their authenticated agent record.
|
||||
# Names are not unique.
|
||||
if actor.actor_type == "agent" and actor.agent:
|
||||
return await heartbeat_agent(
|
||||
agent_id=str(actor.agent.id),
|
||||
@@ -793,7 +857,9 @@ async def heartbeat_or_create_agent(
|
||||
agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
agent.provision_requested_at = utcnow()
|
||||
agent.provision_action = "provision"
|
||||
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
|
||||
session_key, session_error = await _ensure_gateway_session(
|
||||
agent.name, client_config,
|
||||
)
|
||||
agent.openclaw_session_id = session_key
|
||||
session.add(agent)
|
||||
await session.commit()
|
||||
@@ -814,7 +880,9 @@ async def heartbeat_or_create_agent(
|
||||
)
|
||||
await session.commit()
|
||||
try:
|
||||
await provision_agent(agent, board, gateway, raw_token, actor.user, action="provision")
|
||||
await provision_agent(
|
||||
agent, board, gateway, raw_token, actor.user, action="provision",
|
||||
)
|
||||
await _send_wakeup_message(agent, client_config, verb="provisioned")
|
||||
agent.provision_confirm_token_hash = None
|
||||
agent.provision_requested_at = None
|
||||
@@ -864,7 +932,7 @@ async def heartbeat_or_create_agent(
|
||||
)
|
||||
gateway, client_config = await _require_gateway(session, board)
|
||||
await provision_agent(
|
||||
agent, board, gateway, raw_token, actor.user, action="provision"
|
||||
agent, board, gateway, raw_token, actor.user, action="provision",
|
||||
)
|
||||
await _send_wakeup_message(agent, client_config, verb="provisioned")
|
||||
agent.provision_confirm_token_hash = None
|
||||
@@ -903,7 +971,9 @@ async def heartbeat_or_create_agent(
|
||||
write=actor.actor_type == "user",
|
||||
)
|
||||
gateway, client_config = await _require_gateway(session, board)
|
||||
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
|
||||
session_key, session_error = await _ensure_gateway_session(
|
||||
agent.name, client_config,
|
||||
)
|
||||
agent.openclaw_session_id = session_key
|
||||
if session_error:
|
||||
record_activity(
|
||||
@@ -937,15 +1007,18 @@ async def heartbeat_or_create_agent(
|
||||
@router.delete("/{agent_id}", response_model=OkResponse)
|
||||
async def delete_agent(
|
||||
agent_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OkResponse:
|
||||
"""Delete an agent and clean related task state."""
|
||||
agent = await Agent.objects.by_id(agent_id).first(session)
|
||||
if agent is None:
|
||||
return OkResponse()
|
||||
await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
|
||||
|
||||
board = await _require_board(session, str(agent.board_id) if agent.board_id else None)
|
||||
board = await _require_board(
|
||||
session, str(agent.board_id) if agent.board_id else None,
|
||||
)
|
||||
gateway, client_config = await _require_gateway(session, board)
|
||||
try:
|
||||
workspace_path = await cleanup_agent(agent, gateway)
|
||||
@@ -970,7 +1043,7 @@ async def delete_agent(
|
||||
message=f"Deleted agent {agent.name}.",
|
||||
agent_id=None,
|
||||
)
|
||||
now = datetime.now()
|
||||
now = utcnow()
|
||||
await crud.update_where(
|
||||
session,
|
||||
Task,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Authentication bootstrap endpoints for the Mission Control API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
@@ -6,10 +8,12 @@ from app.core.auth import AuthContext, get_auth_context
|
||||
from app.schemas.users import UserRead
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
AUTH_CONTEXT_DEP = Depends(get_auth_context)
|
||||
|
||||
|
||||
@router.post("/bootstrap", response_model=UserRead)
|
||||
async def bootstrap_user(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
|
||||
async def bootstrap_user(auth: AuthContext = AUTH_CONTEXT_DEP) -> UserRead:
|
||||
"""Return the authenticated user profile from token claims."""
|
||||
if auth.actor_type != "user" or auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return UserRead.model_validate(auth.user)
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""Board onboarding endpoints for user/agent collaboration."""
|
||||
# ruff: noqa: E501
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api.deps import (
|
||||
ActorContext,
|
||||
@@ -18,15 +21,17 @@ from app.api.deps import (
|
||||
require_admin_or_agent,
|
||||
)
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||
from app.core.auth import AuthContext
|
||||
from app.core.config import settings
|
||||
from app.core.time import utcnow
|
||||
from app.db.session import get_session
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
send_message,
|
||||
)
|
||||
from app.models.agents import Agent
|
||||
from app.models.board_onboarding import BoardOnboardingSession
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.schemas.board_onboarding import (
|
||||
BoardOnboardingAgentComplete,
|
||||
@@ -41,12 +46,24 @@ from app.schemas.board_onboarding import (
|
||||
from app.schemas.boards import BoardRead
|
||||
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
from app.models.boards import Board
|
||||
|
||||
router = APIRouter(prefix="/boards/{board_id}/onboarding", tags=["board-onboarding"])
|
||||
logger = logging.getLogger(__name__)
|
||||
BOARD_USER_READ_DEP = Depends(get_board_for_user_read)
|
||||
BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write)
|
||||
BOARD_OR_404_DEP = Depends(get_board_or_404)
|
||||
SESSION_DEP = Depends(get_session)
|
||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||
ADMIN_AUTH_DEP = Depends(require_admin_auth)
|
||||
|
||||
|
||||
async def _gateway_config(
|
||||
session: AsyncSession, board: Board
|
||||
session: AsyncSession, board: Board,
|
||||
) -> tuple[Gateway, GatewayClientConfig]:
|
||||
if not board.gateway_id:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
@@ -61,7 +78,7 @@ def _build_session_key(agent_name: str) -> str:
|
||||
return f"agent:{slug or uuid4().hex}:main"
|
||||
|
||||
|
||||
def _lead_agent_name(board: Board) -> str:
|
||||
def _lead_agent_name(_board: Board) -> str:
|
||||
return "Lead Agent"
|
||||
|
||||
|
||||
@@ -69,7 +86,7 @@ def _lead_session_key(board: Board) -> str:
|
||||
return f"agent:lead-{board.id}:main"
|
||||
|
||||
|
||||
async def _ensure_lead_agent(
|
||||
async def _ensure_lead_agent( # noqa: PLR0913
|
||||
session: AsyncSession,
|
||||
board: Board,
|
||||
gateway: Gateway,
|
||||
@@ -100,7 +117,11 @@ async def _ensure_lead_agent(
|
||||
}
|
||||
if identity_profile:
|
||||
merged_identity_profile.update(
|
||||
{key: value.strip() for key, value in identity_profile.items() if value.strip()}
|
||||
{
|
||||
key: value.strip()
|
||||
for key, value in identity_profile.items()
|
||||
if value.strip()
|
||||
},
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
@@ -121,7 +142,9 @@ async def _ensure_lead_agent(
|
||||
await session.refresh(agent)
|
||||
|
||||
try:
|
||||
await provision_agent(agent, board, gateway, raw_token, auth.user, action="provision")
|
||||
await provision_agent(
|
||||
agent, board, gateway, raw_token, auth.user, action="provision",
|
||||
)
|
||||
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name)
|
||||
await send_message(
|
||||
(
|
||||
@@ -141,9 +164,10 @@ async def _ensure_lead_agent(
|
||||
|
||||
@router.get("", response_model=BoardOnboardingRead)
|
||||
async def get_onboarding(
|
||||
board: Board = Depends(get_board_for_user_read),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
board: Board = BOARD_USER_READ_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> BoardOnboardingSession:
|
||||
"""Get the latest onboarding session for a board."""
|
||||
onboarding = (
|
||||
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
|
||||
.order_by(col(BoardOnboardingSession.updated_at).desc())
|
||||
@@ -156,10 +180,11 @@ async def get_onboarding(
|
||||
|
||||
@router.post("/start", response_model=BoardOnboardingRead)
|
||||
async def start_onboarding(
|
||||
payload: BoardOnboardingStart,
|
||||
board: Board = Depends(get_board_for_user_write),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_payload: BoardOnboardingStart,
|
||||
board: Board = BOARD_USER_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> BoardOnboardingSession:
|
||||
"""Start onboarding and send instructions to the gateway main agent."""
|
||||
onboarding = (
|
||||
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
|
||||
.filter(col(BoardOnboardingSession.status) == "active")
|
||||
@@ -219,15 +244,21 @@ async def start_onboarding(
|
||||
|
||||
try:
|
||||
await ensure_session(session_key, config=config, label="Main Agent")
|
||||
await send_message(prompt, session_key=session_key, config=config, deliver=False)
|
||||
await send_message(
|
||||
prompt, session_key=session_key, config=config, deliver=False,
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
|
||||
onboarding = BoardOnboardingSession(
|
||||
board_id=board.id,
|
||||
session_key=session_key,
|
||||
status="active",
|
||||
messages=[{"role": "user", "content": prompt, "timestamp": utcnow().isoformat()}],
|
||||
messages=[
|
||||
{"role": "user", "content": prompt, "timestamp": utcnow().isoformat()},
|
||||
],
|
||||
)
|
||||
session.add(onboarding)
|
||||
await session.commit()
|
||||
@@ -238,9 +269,10 @@ async def start_onboarding(
|
||||
@router.post("/answer", response_model=BoardOnboardingRead)
|
||||
async def answer_onboarding(
|
||||
payload: BoardOnboardingAnswer,
|
||||
board: Board = Depends(get_board_for_user_write),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
board: Board = BOARD_USER_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> BoardOnboardingSession:
|
||||
"""Send a user onboarding answer to the gateway main agent."""
|
||||
onboarding = (
|
||||
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
|
||||
.order_by(col(BoardOnboardingSession.updated_at).desc())
|
||||
@@ -255,15 +287,22 @@ async def answer_onboarding(
|
||||
answer_text = f"{payload.answer}: {payload.other_text}"
|
||||
|
||||
messages = list(onboarding.messages or [])
|
||||
messages.append({"role": "user", "content": answer_text, "timestamp": utcnow().isoformat()})
|
||||
messages.append(
|
||||
{"role": "user", "content": answer_text, "timestamp": utcnow().isoformat()},
|
||||
)
|
||||
|
||||
try:
|
||||
await ensure_session(onboarding.session_key, config=config, label="Main Agent")
|
||||
await send_message(
|
||||
answer_text, session_key=onboarding.session_key, config=config, deliver=False
|
||||
answer_text,
|
||||
session_key=onboarding.session_key,
|
||||
config=config,
|
||||
deliver=False,
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
|
||||
onboarding.messages = messages
|
||||
onboarding.updated_at = utcnow()
|
||||
@@ -276,10 +315,11 @@ async def answer_onboarding(
|
||||
@router.post("/agent", response_model=BoardOnboardingRead)
|
||||
async def agent_onboarding_update(
|
||||
payload: BoardOnboardingAgentUpdate,
|
||||
board: Board = Depends(get_board_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
board: Board = BOARD_OR_404_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
actor: ActorContext = ACTOR_DEP,
|
||||
) -> BoardOnboardingSession:
|
||||
"""Store onboarding updates submitted by the gateway main agent."""
|
||||
if actor.actor_type != "agent" or actor.agent is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
agent = actor.agent
|
||||
@@ -288,9 +328,13 @@ async def agent_onboarding_update(
|
||||
|
||||
if board.gateway_id:
|
||||
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
|
||||
if gateway and gateway.main_session_key and agent.openclaw_session_id:
|
||||
if agent.openclaw_session_id != gateway.main_session_key:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if (
|
||||
gateway
|
||||
and gateway.main_session_key
|
||||
and agent.openclaw_session_id
|
||||
and agent.openclaw_session_id != gateway.main_session_key
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
onboarding = (
|
||||
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
|
||||
@@ -315,9 +359,13 @@ async def agent_onboarding_update(
|
||||
if isinstance(payload, BoardOnboardingAgentComplete):
|
||||
onboarding.draft_goal = payload_data
|
||||
onboarding.status = "completed"
|
||||
messages.append({"role": "assistant", "content": payload_text, "timestamp": now})
|
||||
messages.append(
|
||||
{"role": "assistant", "content": payload_text, "timestamp": now},
|
||||
)
|
||||
else:
|
||||
messages.append({"role": "assistant", "content": payload_text, "timestamp": now})
|
||||
messages.append(
|
||||
{"role": "assistant", "content": payload_text, "timestamp": now},
|
||||
)
|
||||
|
||||
onboarding.messages = messages
|
||||
onboarding.updated_at = utcnow()
|
||||
@@ -334,12 +382,13 @@ async def agent_onboarding_update(
|
||||
|
||||
|
||||
@router.post("/confirm", response_model=BoardRead)
|
||||
async def confirm_onboarding(
|
||||
async def confirm_onboarding( # noqa: C901, PLR0912, PLR0915
|
||||
payload: BoardOnboardingConfirm,
|
||||
board: Board = Depends(get_board_for_user_write),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(require_admin_auth),
|
||||
board: Board = BOARD_USER_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = ADMIN_AUTH_DEP,
|
||||
) -> Board:
|
||||
"""Confirm onboarding results and provision the board lead agent."""
|
||||
onboarding = (
|
||||
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
|
||||
.order_by(col(BoardOnboardingSession.updated_at).desc())
|
||||
@@ -409,7 +458,9 @@ async def confirm_onboarding(
|
||||
if lead_agent.update_cadence:
|
||||
lead_identity_profile["update_cadence"] = lead_agent.update_cadence
|
||||
if lead_agent.custom_instructions:
|
||||
lead_identity_profile["custom_instructions"] = lead_agent.custom_instructions
|
||||
lead_identity_profile["custom_instructions"] = (
|
||||
lead_agent.custom_instructions
|
||||
)
|
||||
|
||||
gateway, config = await _gateway_config(session, board)
|
||||
session.add(board)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Board CRUD and snapshot endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api.deps import (
|
||||
get_board_for_actor_read,
|
||||
@@ -47,9 +49,23 @@ from app.services.board_group_snapshot import build_board_group_snapshot
|
||||
from app.services.board_snapshot import build_board_snapshot
|
||||
from app.services.organizations import OrganizationContext, board_access_filter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter(prefix="/boards", tags=["boards"])
|
||||
|
||||
AGENT_SESSION_PREFIX = "agent"
|
||||
SESSION_DEP = Depends(get_session)
|
||||
ORG_ADMIN_DEP = Depends(require_org_admin)
|
||||
ORG_MEMBER_DEP = Depends(require_org_member)
|
||||
BOARD_USER_READ_DEP = Depends(get_board_for_user_read)
|
||||
BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write)
|
||||
BOARD_ACTOR_READ_DEP = Depends(get_board_for_actor_read)
|
||||
GATEWAY_ID_QUERY = Query(default=None)
|
||||
BOARD_GROUP_ID_QUERY = Query(default=None)
|
||||
INCLUDE_SELF_QUERY = Query(default=False)
|
||||
INCLUDE_DONE_QUERY = Query(default=False)
|
||||
PER_BOARD_TASK_LIMIT_QUERY = Query(default=5, ge=0, le=100)
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
@@ -83,10 +99,12 @@ async def _require_gateway(
|
||||
|
||||
async def _require_gateway_for_create(
|
||||
payload: BoardCreate,
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> Gateway:
|
||||
return await _require_gateway(session, payload.gateway_id, organization_id=ctx.organization.id)
|
||||
return await _require_gateway(
|
||||
session, payload.gateway_id, organization_id=ctx.organization.id,
|
||||
)
|
||||
|
||||
|
||||
async def _require_board_group(
|
||||
@@ -111,8 +129,8 @@ async def _require_board_group(
|
||||
|
||||
async def _require_board_group_for_create(
|
||||
payload: BoardCreate,
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> BoardGroup | None:
|
||||
if payload.board_group_id is None:
|
||||
return None
|
||||
@@ -123,6 +141,10 @@ async def _require_board_group_for_create(
|
||||
)
|
||||
|
||||
|
||||
GATEWAY_CREATE_DEP = Depends(_require_gateway_for_create)
|
||||
BOARD_GROUP_CREATE_DEP = Depends(_require_board_group_for_create)
|
||||
|
||||
|
||||
async def _apply_board_update(
|
||||
*,
|
||||
payload: BoardUpdate,
|
||||
@@ -132,7 +154,7 @@ async def _apply_board_update(
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "gateway_id" in updates:
|
||||
await _require_gateway(
|
||||
session, updates["gateway_id"], organization_id=board.organization_id
|
||||
session, updates["gateway_id"], organization_id=board.organization_id,
|
||||
)
|
||||
if "board_group_id" in updates and updates["board_group_id"] is not None:
|
||||
await _require_board_group(
|
||||
@@ -141,13 +163,15 @@ async def _apply_board_update(
|
||||
organization_id=board.organization_id,
|
||||
)
|
||||
crud.apply_updates(board, updates)
|
||||
if updates.get("board_type") == "goal":
|
||||
if (
|
||||
updates.get("board_type") == "goal"
|
||||
and (not board.objective or not board.success_metrics)
|
||||
):
|
||||
# Validate only when explicitly switching to goal boards.
|
||||
if not board.objective or not board.success_metrics:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Goal boards require objective and success_metrics",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Goal boards require objective and success_metrics",
|
||||
)
|
||||
if not board.gateway_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
@@ -158,7 +182,7 @@ async def _apply_board_update(
|
||||
|
||||
|
||||
async def _board_gateway(
|
||||
session: AsyncSession, board: Board
|
||||
session: AsyncSession, board: Board,
|
||||
) -> tuple[Gateway | None, GatewayClientConfig | None]:
|
||||
if not board.gateway_id:
|
||||
return None, None
|
||||
@@ -218,28 +242,32 @@ async def _cleanup_agent_on_gateway(
|
||||
|
||||
@router.get("", response_model=DefaultLimitOffsetPage[BoardRead])
|
||||
async def list_boards(
|
||||
gateway_id: UUID | None = Query(default=None),
|
||||
board_group_id: UUID | None = Query(default=None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_member),
|
||||
gateway_id: UUID | None = GATEWAY_ID_QUERY,
|
||||
board_group_id: UUID | None = BOARD_GROUP_ID_QUERY,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||
) -> DefaultLimitOffsetPage[BoardRead]:
|
||||
"""List boards visible to the current organization member."""
|
||||
statement = select(Board).where(board_access_filter(ctx.member, write=False))
|
||||
if gateway_id is not None:
|
||||
statement = statement.where(col(Board.gateway_id) == gateway_id)
|
||||
if board_group_id is not None:
|
||||
statement = statement.where(col(Board.board_group_id) == board_group_id)
|
||||
statement = statement.order_by(func.lower(col(Board.name)).asc(), col(Board.created_at).desc())
|
||||
statement = statement.order_by(
|
||||
func.lower(col(Board.name)).asc(), col(Board.created_at).desc(),
|
||||
)
|
||||
return await paginate(session, statement)
|
||||
|
||||
|
||||
@router.post("", response_model=BoardRead)
|
||||
async def create_board(
|
||||
payload: BoardCreate,
|
||||
_gateway: Gateway = Depends(_require_gateway_for_create),
|
||||
_board_group: BoardGroup | None = Depends(_require_board_group_for_create),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
_gateway: Gateway = GATEWAY_CREATE_DEP,
|
||||
_board_group: BoardGroup | None = BOARD_GROUP_CREATE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> Board:
|
||||
"""Create a board in the active organization."""
|
||||
data = payload.model_dump()
|
||||
data["organization_id"] = ctx.organization.id
|
||||
return await crud.create(session, Board, **data)
|
||||
@@ -247,27 +275,31 @@ async def create_board(
|
||||
|
||||
@router.get("/{board_id}", response_model=BoardRead)
|
||||
def get_board(
|
||||
board: Board = Depends(get_board_for_user_read),
|
||||
board: Board = BOARD_USER_READ_DEP,
|
||||
) -> Board:
|
||||
"""Get a board by id."""
|
||||
return board
|
||||
|
||||
|
||||
@router.get("/{board_id}/snapshot", response_model=BoardSnapshot)
|
||||
async def get_board_snapshot(
|
||||
board: Board = Depends(get_board_for_actor_read),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
board: Board = BOARD_ACTOR_READ_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> BoardSnapshot:
|
||||
"""Get a board snapshot view model."""
|
||||
return await build_board_snapshot(session, board)
|
||||
|
||||
|
||||
@router.get("/{board_id}/group-snapshot", response_model=BoardGroupSnapshot)
|
||||
async def get_board_group_snapshot(
|
||||
include_self: bool = Query(default=False),
|
||||
include_done: bool = Query(default=False),
|
||||
per_board_task_limit: int = Query(default=5, ge=0, le=100),
|
||||
board: Board = Depends(get_board_for_actor_read),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
*,
|
||||
include_self: bool = INCLUDE_SELF_QUERY,
|
||||
include_done: bool = INCLUDE_DONE_QUERY,
|
||||
per_board_task_limit: int = PER_BOARD_TASK_LIMIT_QUERY,
|
||||
board: Board = BOARD_ACTOR_READ_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> BoardGroupSnapshot:
|
||||
"""Get a grouped snapshot across related boards."""
|
||||
return await build_board_group_snapshot(
|
||||
session,
|
||||
board=board,
|
||||
@@ -280,19 +312,23 @@ async def get_board_group_snapshot(
|
||||
@router.patch("/{board_id}", response_model=BoardRead)
|
||||
async def update_board(
|
||||
payload: BoardUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
board: Board = Depends(get_board_for_user_write),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
board: Board = BOARD_USER_WRITE_DEP,
|
||||
) -> Board:
|
||||
"""Update mutable board properties."""
|
||||
return await _apply_board_update(payload=payload, session=session, board=board)
|
||||
|
||||
|
||||
@router.delete("/{board_id}", response_model=OkResponse)
|
||||
async def delete_board(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
board: Board = Depends(get_board_for_user_write),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
board: Board = BOARD_USER_WRITE_DEP,
|
||||
) -> OkResponse:
|
||||
"""Delete a board and all dependent records."""
|
||||
agents = await Agent.objects.filter_by(board_id=board.id).all(session)
|
||||
task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id)))
|
||||
task_ids = list(
|
||||
await session.exec(select(Task.id).where(Task.board_id == board.id)),
|
||||
)
|
||||
|
||||
config, client_config = await _board_gateway(session, board)
|
||||
if config and client_config:
|
||||
@@ -307,20 +343,31 @@ async def delete_board(
|
||||
|
||||
if task_ids:
|
||||
await crud.delete_where(
|
||||
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
|
||||
session,
|
||||
ActivityEvent,
|
||||
col(ActivityEvent.task_id).in_(task_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(session, TaskDependency, col(TaskDependency.board_id) == board.id)
|
||||
await crud.delete_where(session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id)
|
||||
await crud.delete_where(
|
||||
session, TaskDependency, col(TaskDependency.board_id) == board.id,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id,
|
||||
)
|
||||
|
||||
# Approvals can reference tasks and agents, so delete before both.
|
||||
await crud.delete_where(session, Approval, col(Approval.board_id) == board.id)
|
||||
|
||||
await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id)
|
||||
await crud.delete_where(
|
||||
session, BoardOnboardingSession, col(BoardOnboardingSession.board_id) == board.id
|
||||
session,
|
||||
BoardOnboardingSession,
|
||||
col(BoardOnboardingSession.board_id) == board.id,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, OrganizationBoardAccess, col(OrganizationBoardAccess.board_id) == board.id
|
||||
session,
|
||||
OrganizationBoardAccess,
|
||||
col(OrganizationBoardAccess.board_id) == board.id,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
@@ -328,14 +375,17 @@ async def delete_board(
|
||||
col(OrganizationInviteBoardAccess.board_id) == board.id,
|
||||
)
|
||||
|
||||
# Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so
|
||||
# Tasks reference agents and have dependent records.
|
||||
# delete tasks before agents.
|
||||
await crud.delete_where(session, Task, col(Task.board_id) == board.id)
|
||||
|
||||
if agents:
|
||||
agent_ids = [agent.id for agent in agents]
|
||||
await crud.delete_where(
|
||||
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
|
||||
session,
|
||||
ActivityEvent,
|
||||
col(ActivityEvent.agent_id).in_(agent_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(session, Agent, col(Agent.id).in_(agent_ids))
|
||||
await session.delete(board)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
"""Organization management endpoints and membership/invite flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from typing import Any, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api.deps import require_org_admin, require_org_member
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.core.auth import get_auth_context
|
||||
from app.core.time import utcnow
|
||||
from app.db import crud
|
||||
from app.db.pagination import paginate
|
||||
@@ -63,10 +64,21 @@ from app.services.organizations import (
|
||||
set_active_organization,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
|
||||
router = APIRouter(prefix="/organizations", tags=["organizations"])
|
||||
SESSION_DEP = Depends(get_session)
|
||||
AUTH_DEP = Depends(get_auth_context)
|
||||
ORG_MEMBER_DEP = Depends(require_org_member)
|
||||
ORG_ADMIN_DEP = Depends(require_org_admin)
|
||||
|
||||
|
||||
def _member_to_read(member: OrganizationMember, user: User | None) -> OrganizationMemberRead:
|
||||
def _member_to_read(
|
||||
member: OrganizationMember, user: User | None,
|
||||
) -> OrganizationMemberRead:
|
||||
model = OrganizationMemberRead.model_validate(member, from_attributes=True)
|
||||
if user is not None:
|
||||
model.user = OrganizationUserRead.model_validate(user, from_attributes=True)
|
||||
@@ -100,9 +112,10 @@ async def _require_org_invite(
|
||||
@router.post("", response_model=OrganizationRead)
|
||||
async def create_organization(
|
||||
payload: OrganizationCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
) -> OrganizationRead:
|
||||
"""Create an organization and assign the caller as owner."""
|
||||
if auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
name = payload.name.strip()
|
||||
@@ -110,7 +123,9 @@ async def create_organization(
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
existing = (
|
||||
await session.exec(
|
||||
select(Organization).where(func.lower(col(Organization.name)) == name.lower())
|
||||
select(Organization).where(
|
||||
func.lower(col(Organization.name)) == name.lower(),
|
||||
),
|
||||
)
|
||||
).first()
|
||||
if existing is not None:
|
||||
@@ -140,19 +155,25 @@ async def create_organization(
|
||||
|
||||
@router.get("/me/list", response_model=list[OrganizationListItem])
|
||||
async def list_my_organizations(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
) -> list[OrganizationListItem]:
|
||||
"""List organizations where the current user is a member."""
|
||||
if auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
await get_active_membership(session, auth.user)
|
||||
db_user = await User.objects.by_id(auth.user.id).first(session)
|
||||
active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id
|
||||
active_id = (
|
||||
db_user.active_organization_id if db_user else auth.user.active_organization_id
|
||||
)
|
||||
|
||||
statement = (
|
||||
select(Organization, OrganizationMember)
|
||||
.join(OrganizationMember, col(OrganizationMember.organization_id) == col(Organization.id))
|
||||
.join(
|
||||
OrganizationMember,
|
||||
col(OrganizationMember.organization_id) == col(Organization.id),
|
||||
)
|
||||
.where(col(OrganizationMember.user_id) == auth.user.id)
|
||||
.order_by(func.lower(col(Organization.name)).asc())
|
||||
)
|
||||
@@ -171,30 +192,37 @@ async def list_my_organizations(
|
||||
@router.patch("/me/active", response_model=OrganizationRead)
|
||||
async def set_active_org(
|
||||
payload: OrganizationActiveUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
) -> OrganizationRead:
|
||||
"""Set the caller's active organization."""
|
||||
if auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
member = await set_active_organization(
|
||||
session, user=auth.user, organization_id=payload.organization_id
|
||||
session, user=auth.user, organization_id=payload.organization_id,
|
||||
)
|
||||
organization = await Organization.objects.by_id(member.organization_id).first(
|
||||
session,
|
||||
)
|
||||
organization = await Organization.objects.by_id(member.organization_id).first(session)
|
||||
if organization is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
return OrganizationRead.model_validate(organization, from_attributes=True)
|
||||
|
||||
|
||||
@router.get("/me", response_model=OrganizationRead)
|
||||
async def get_my_org(ctx: OrganizationContext = Depends(require_org_member)) -> OrganizationRead:
|
||||
async def get_my_org(
|
||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||
) -> OrganizationRead:
|
||||
"""Return the caller's active organization."""
|
||||
return OrganizationRead.model_validate(ctx.organization, from_attributes=True)
|
||||
|
||||
|
||||
@router.delete("/me", response_model=OkResponse)
|
||||
async def delete_my_org(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OkResponse:
|
||||
"""Delete the active organization and related entities."""
|
||||
if ctx.member.role != "owner":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -206,28 +234,39 @@ async def delete_my_org(
|
||||
task_ids = select(Task.id).where(col(Task.board_id).in_(board_ids))
|
||||
agent_ids = select(Agent.id).where(col(Agent.board_id).in_(board_ids))
|
||||
member_ids = select(OrganizationMember.id).where(
|
||||
col(OrganizationMember.organization_id) == org_id
|
||||
col(OrganizationMember.organization_id) == org_id,
|
||||
)
|
||||
invite_ids = select(OrganizationInvite.id).where(
|
||||
col(OrganizationInvite.organization_id) == org_id
|
||||
col(OrganizationInvite.organization_id) == org_id,
|
||||
)
|
||||
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
|
||||
|
||||
await crud.delete_where(
|
||||
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
|
||||
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
|
||||
session,
|
||||
ActivityEvent,
|
||||
col(ActivityEvent.agent_id).in_(agent_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, TaskDependency, col(TaskDependency.board_id).in_(board_ids), commit=False
|
||||
session,
|
||||
TaskDependency,
|
||||
col(TaskDependency.board_id).in_(board_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, TaskFingerprint, col(TaskFingerprint.board_id).in_(board_ids), commit=False
|
||||
session,
|
||||
TaskFingerprint,
|
||||
col(TaskFingerprint.board_id).in_(board_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(session, Approval, col(Approval.board_id).in_(board_ids), commit=False)
|
||||
await crud.delete_where(
|
||||
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False
|
||||
session, Approval, col(Approval.board_id).in_(board_ids), commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
@@ -259,9 +298,15 @@ async def delete_my_org(
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(session, Task, col(Task.board_id).in_(board_ids), commit=False)
|
||||
await crud.delete_where(session, Agent, col(Agent.board_id).in_(board_ids), commit=False)
|
||||
await crud.delete_where(session, Board, col(Board.organization_id) == org_id, commit=False)
|
||||
await crud.delete_where(
|
||||
session, Task, col(Task.board_id).in_(board_ids), commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, Agent, col(Agent.board_id).in_(board_ids), commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, Board, col(Board.organization_id) == org_id, commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
BoardGroupMemory,
|
||||
@@ -269,9 +314,11 @@ async def delete_my_org(
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False
|
||||
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, Gateway, col(Gateway.organization_id) == org_id, commit=False,
|
||||
)
|
||||
await crud.delete_where(session, Gateway, col(Gateway.organization_id) == org_id, commit=False)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
OrganizationInvite,
|
||||
@@ -291,32 +338,39 @@ async def delete_my_org(
|
||||
active_organization_id=None,
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(session, Organization, col(Organization.id) == org_id, commit=False)
|
||||
await crud.delete_where(
|
||||
session, Organization, col(Organization.id) == org_id, commit=False,
|
||||
)
|
||||
await session.commit()
|
||||
return OkResponse()
|
||||
|
||||
|
||||
@router.get("/me/member", response_model=OrganizationMemberRead)
|
||||
async def get_my_membership(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_member),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||
) -> OrganizationMemberRead:
|
||||
"""Get the caller's membership record in the active organization."""
|
||||
user = await User.objects.by_id(ctx.member.user_id).first(session)
|
||||
access_rows = await OrganizationBoardAccess.objects.filter_by(
|
||||
organization_member_id=ctx.member.id
|
||||
organization_member_id=ctx.member.id,
|
||||
).all(session)
|
||||
model = _member_to_read(ctx.member, user)
|
||||
model.board_access = [
|
||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
|
||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
|
||||
for row in access_rows
|
||||
]
|
||||
return model
|
||||
|
||||
|
||||
@router.get("/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead])
|
||||
@router.get(
|
||||
"/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead],
|
||||
)
|
||||
async def list_org_members(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_member),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||
) -> DefaultLimitOffsetPage[OrganizationMemberRead]:
|
||||
"""List members for the active organization."""
|
||||
statement = (
|
||||
select(OrganizationMember, User)
|
||||
.join(User, col(User.id) == col(OrganizationMember.user_id))
|
||||
@@ -336,9 +390,10 @@ async def list_org_members(
|
||||
@router.get("/me/members/{member_id}", response_model=OrganizationMemberRead)
|
||||
async def get_org_member(
|
||||
member_id: UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_member),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||
) -> OrganizationMemberRead:
|
||||
"""Get a specific organization member by id."""
|
||||
member = await _require_org_member(
|
||||
session,
|
||||
organization_id=ctx.organization.id,
|
||||
@@ -348,11 +403,12 @@ async def get_org_member(
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
user = await User.objects.by_id(member.user_id).first(session)
|
||||
access_rows = await OrganizationBoardAccess.objects.filter_by(
|
||||
organization_member_id=member.id
|
||||
organization_member_id=member.id,
|
||||
).all(session)
|
||||
model = _member_to_read(member, user)
|
||||
model.board_access = [
|
||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
|
||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
|
||||
for row in access_rows
|
||||
]
|
||||
return model
|
||||
|
||||
@@ -361,9 +417,10 @@ async def get_org_member(
|
||||
async def update_org_member(
|
||||
member_id: UUID,
|
||||
payload: OrganizationMemberUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OrganizationMemberRead:
|
||||
"""Update a member's role in the organization."""
|
||||
member = await _require_org_member(
|
||||
session,
|
||||
organization_id=ctx.organization.id,
|
||||
@@ -382,9 +439,10 @@ async def update_org_member(
|
||||
async def update_member_access(
|
||||
member_id: UUID,
|
||||
payload: OrganizationMemberAccessUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OrganizationMemberRead:
|
||||
"""Update board-level access settings for a member."""
|
||||
member = await _require_org_member(
|
||||
session,
|
||||
organization_id=ctx.organization.id,
|
||||
@@ -395,7 +453,9 @@ async def update_member_access(
|
||||
if board_ids:
|
||||
valid_board_ids = {
|
||||
board.id
|
||||
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
|
||||
for board in await Board.objects.filter_by(
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
.filter(col(Board.id).in_(board_ids))
|
||||
.all(session)
|
||||
}
|
||||
@@ -412,9 +472,10 @@ async def update_member_access(
|
||||
@router.delete("/me/members/{member_id}", response_model=OkResponse)
|
||||
async def remove_org_member(
|
||||
member_id: UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OkResponse:
|
||||
"""Remove a member from the active organization."""
|
||||
member = await _require_org_member(
|
||||
session,
|
||||
organization_id=ctx.organization.id,
|
||||
@@ -432,7 +493,9 @@ async def remove_org_member(
|
||||
)
|
||||
if member.role == "owner":
|
||||
owners = (
|
||||
await OrganizationMember.objects.filter_by(organization_id=ctx.organization.id)
|
||||
await OrganizationMember.objects.filter_by(
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
.filter(col(OrganizationMember.role) == "owner")
|
||||
.all(session)
|
||||
)
|
||||
@@ -463,7 +526,9 @@ async def remove_org_member(
|
||||
user.active_organization_id = fallback_membership
|
||||
else:
|
||||
user.active_organization_id = (
|
||||
fallback_membership.organization_id if fallback_membership is not None else None
|
||||
fallback_membership.organization_id
|
||||
if fallback_membership is not None
|
||||
else None
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
@@ -471,11 +536,14 @@ async def remove_org_member(
|
||||
return OkResponse()
|
||||
|
||||
|
||||
@router.get("/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead])
|
||||
@router.get(
|
||||
"/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead],
|
||||
)
|
||||
async def list_org_invites(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
|
||||
"""List pending invites for the active organization."""
|
||||
statement = (
|
||||
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
|
||||
.filter(col(OrganizationInvite.accepted_at).is_(None))
|
||||
@@ -488,9 +556,10 @@ async def list_org_invites(
|
||||
@router.post("/me/invites", response_model=OrganizationInviteRead)
|
||||
async def create_org_invite(
|
||||
payload: OrganizationInviteCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OrganizationInviteRead:
|
||||
"""Create an organization invite for an email address."""
|
||||
email = normalize_invited_email(payload.invited_email)
|
||||
if not email:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
@@ -526,13 +595,17 @@ async def create_org_invite(
|
||||
if board_ids:
|
||||
valid_board_ids = {
|
||||
board.id
|
||||
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
|
||||
for board in await Board.objects.filter_by(
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
.filter(col(Board.id).in_(board_ids))
|
||||
.all(session)
|
||||
}
|
||||
if valid_board_ids != board_ids:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
await apply_invite_board_access(session, invite=invite, entries=payload.board_access)
|
||||
await apply_invite_board_access(
|
||||
session, invite=invite, entries=payload.board_access,
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(invite)
|
||||
return OrganizationInviteRead.model_validate(invite, from_attributes=True)
|
||||
@@ -541,9 +614,10 @@ async def create_org_invite(
|
||||
@router.delete("/me/invites/{invite_id}", response_model=OrganizationInviteRead)
|
||||
async def revoke_org_invite(
|
||||
invite_id: UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OrganizationInviteRead:
|
||||
"""Revoke a pending invite from the active organization."""
|
||||
invite = await _require_org_invite(
|
||||
session,
|
||||
organization_id=ctx.organization.id,
|
||||
@@ -562,9 +636,10 @@ async def revoke_org_invite(
|
||||
@router.post("/invites/accept", response_model=OrganizationMemberRead)
|
||||
async def accept_org_invite(
|
||||
payload: OrganizationInviteAccept,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
) -> OrganizationMemberRead:
|
||||
"""Accept an invite and return resulting membership."""
|
||||
if auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
invite = await OrganizationInvite.objects.filter(
|
||||
@@ -573,11 +648,13 @@ async def accept_org_invite(
|
||||
).first(session)
|
||||
if invite is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
if invite.invited_email and auth.user.email:
|
||||
if normalize_invited_email(invite.invited_email) != normalize_invited_email(
|
||||
auth.user.email
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if (
|
||||
invite.invited_email
|
||||
and auth.user.email
|
||||
and normalize_invited_email(invite.invited_email)
|
||||
!= normalize_invited_email(auth.user.email)
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
existing = await get_member(
|
||||
session,
|
||||
|
||||
@@ -1,41 +1,54 @@
|
||||
"""API-level thin wrapper around query-set helpers with HTTP conveniences."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from app.db.queryset import QuerySet, qs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
ModelT = TypeVar("ModelT")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class APIQuerySet(Generic[ModelT]):
|
||||
"""Immutable query-set wrapper tailored for API-layer usage."""
|
||||
|
||||
queryset: QuerySet[ModelT]
|
||||
|
||||
@property
|
||||
def statement(self) -> SelectOfScalar[ModelT]:
|
||||
"""Expose the underlying SQL statement for advanced composition."""
|
||||
return self.queryset.statement
|
||||
|
||||
def filter(self, *criteria: Any) -> APIQuerySet[ModelT]:
|
||||
def filter(self, *criteria: object) -> APIQuerySet[ModelT]:
|
||||
"""Return a new queryset with additional SQL criteria applied."""
|
||||
return APIQuerySet(self.queryset.filter(*criteria))
|
||||
|
||||
def order_by(self, *ordering: Any) -> APIQuerySet[ModelT]:
|
||||
def order_by(self, *ordering: object) -> APIQuerySet[ModelT]:
|
||||
"""Return a new queryset with ordering clauses applied."""
|
||||
return APIQuerySet(self.queryset.order_by(*ordering))
|
||||
|
||||
def limit(self, value: int) -> APIQuerySet[ModelT]:
|
||||
"""Return a new queryset with a row limit applied."""
|
||||
return APIQuerySet(self.queryset.limit(value))
|
||||
|
||||
def offset(self, value: int) -> APIQuerySet[ModelT]:
|
||||
"""Return a new queryset with an offset applied."""
|
||||
return APIQuerySet(self.queryset.offset(value))
|
||||
|
||||
async def all(self, session: AsyncSession) -> list[ModelT]:
|
||||
"""Fetch all rows for the current queryset."""
|
||||
return await self.queryset.all(session)
|
||||
|
||||
async def first(self, session: AsyncSession) -> ModelT | None:
|
||||
"""Fetch the first row for the current queryset, if present."""
|
||||
return await self.queryset.first(session)
|
||||
|
||||
async def first_or_404(
|
||||
@@ -44,6 +57,7 @@ class APIQuerySet(Generic[ModelT]):
|
||||
*,
|
||||
detail: str | None = None,
|
||||
) -> ModelT:
|
||||
"""Fetch the first row or raise HTTP 404 when no row exists."""
|
||||
obj = await self.first(session)
|
||||
if obj is not None:
|
||||
return obj
|
||||
@@ -53,4 +67,5 @@ class APIQuerySet(Generic[ModelT]):
|
||||
|
||||
|
||||
def api_qs(model: type[ModelT]) -> APIQuerySet[ModelT]:
|
||||
"""Create an APIQuerySet for a SQLModel class."""
|
||||
return APIQuerySet(qs(model))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""API routes for searching and fetching souls-directory markdown entries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
@@ -13,6 +15,7 @@ from app.schemas.souls_directory import (
|
||||
from app.services import souls_directory
|
||||
|
||||
router = APIRouter(prefix="/souls-directory", tags=["souls-directory"])
|
||||
ADMIN_OR_AGENT_DEP = Depends(require_admin_or_agent)
|
||||
|
||||
_SAFE_SEGMENT_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
|
||||
_SAFE_SLUG_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
|
||||
@@ -41,8 +44,9 @@ def _validate_segment(value: str, *, field: str) -> str:
|
||||
async def search(
|
||||
q: str = Query(default="", min_length=0),
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
_actor: ActorContext = Depends(require_admin_or_agent),
|
||||
_actor: ActorContext = ADMIN_OR_AGENT_DEP,
|
||||
) -> SoulsDirectorySearchResponse:
|
||||
"""Search souls-directory entries by handle/slug query text."""
|
||||
refs = await souls_directory.list_souls_directory_refs()
|
||||
matches = souls_directory.search_souls(refs, query=q, limit=limit)
|
||||
items = [
|
||||
@@ -62,12 +66,23 @@ async def search(
|
||||
async def get_markdown(
|
||||
handle: str,
|
||||
slug: str,
|
||||
_actor: ActorContext = Depends(require_admin_or_agent),
|
||||
_actor: ActorContext = ADMIN_OR_AGENT_DEP,
|
||||
) -> SoulsDirectoryMarkdownResponse:
|
||||
"""Fetch markdown content for a validated souls-directory handle and slug."""
|
||||
safe_handle = _validate_segment(handle, field="handle")
|
||||
safe_slug = _validate_segment(slug.removesuffix(".md"), field="slug")
|
||||
try:
|
||||
content = await souls_directory.fetch_soul_markdown(handle=safe_handle, slug=safe_slug)
|
||||
content = await souls_directory.fetch_soul_markdown(
|
||||
handle=safe_handle,
|
||||
slug=safe_slug,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
|
||||
return SoulsDirectoryMarkdownResponse(handle=safe_handle, slug=safe_slug, content=content)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
return SoulsDirectoryMarkdownResponse(
|
||||
handle=safe_handle,
|
||||
slug=safe_slug,
|
||||
content=content,
|
||||
)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""Task API routes for listing, streaming, and mutating board tasks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import asc, desc, or_
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
@@ -23,13 +25,16 @@ from app.api.deps import (
|
||||
require_admin_auth,
|
||||
require_admin_or_agent,
|
||||
)
|
||||
from app.core.auth import AuthContext
|
||||
from app.core.time import utcnow
|
||||
from app.db import crud
|
||||
from app.db.pagination import paginate
|
||||
from app.db.session import async_session_maker, get_session
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
send_message,
|
||||
)
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approvals import Approval
|
||||
@@ -41,7 +46,13 @@ from app.models.tasks import Task
|
||||
from app.schemas.common import OkResponse
|
||||
from app.schemas.errors import BlockedTaskError
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
|
||||
from app.schemas.tasks import (
|
||||
TaskCommentCreate,
|
||||
TaskCommentRead,
|
||||
TaskCreate,
|
||||
TaskRead,
|
||||
TaskUpdate,
|
||||
)
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.mentions import extract_mentions, matches_agent_mention
|
||||
from app.services.organizations import require_board_access
|
||||
@@ -54,6 +65,11 @@ from app.services.task_dependencies import (
|
||||
validate_dependency_update,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
|
||||
router = APIRouter(prefix="/boards/{board_id}/tasks", tags=["tasks"])
|
||||
|
||||
ALLOWED_STATUSES = {"inbox", "in_progress", "review", "done"}
|
||||
@@ -66,6 +82,14 @@ TASK_EVENT_TYPES = {
|
||||
SSE_SEEN_MAX = 2000
|
||||
TASK_SNIPPET_MAX_LEN = 500
|
||||
TASK_SNIPPET_TRUNCATED_LEN = 497
|
||||
BOARD_READ_DEP = Depends(get_board_for_actor_read)
|
||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||
SINCE_QUERY = Query(default=None)
|
||||
STATUS_QUERY = Query(default=None, alias="status")
|
||||
BOARD_WRITE_DEP = Depends(get_board_for_user_write)
|
||||
SESSION_DEP = Depends(get_session)
|
||||
ADMIN_AUTH_DEP = Depends(require_admin_auth)
|
||||
TASK_DEP = Depends(get_task_or_404)
|
||||
|
||||
|
||||
def _comment_validation_error() -> HTTPException:
|
||||
@@ -98,6 +122,7 @@ async def has_valid_recent_comment(
|
||||
agent_id: UUID | None,
|
||||
since: datetime | None,
|
||||
) -> bool:
|
||||
"""Check whether the task has a recent non-empty comment by the agent."""
|
||||
if agent_id is None or since is None:
|
||||
return False
|
||||
statement = (
|
||||
@@ -180,8 +205,8 @@ async def _reconcile_dependents_for_dependency_toggle(
|
||||
await session.exec(
|
||||
select(Task)
|
||||
.where(col(Task.board_id) == board_id)
|
||||
.where(col(Task.id).in_(dependent_ids))
|
||||
)
|
||||
.where(col(Task.id).in_(dependent_ids)),
|
||||
),
|
||||
)
|
||||
reopened = previous_status == "done" and dependency_task.status != "done"
|
||||
|
||||
@@ -204,7 +229,10 @@ async def _reconcile_dependents_for_dependency_toggle(
|
||||
session,
|
||||
event_type="task.status_changed",
|
||||
task_id=dependent.id,
|
||||
message=f"Task returned to inbox: dependency reopened ({dependency_task.title}).",
|
||||
message=(
|
||||
"Task returned to inbox: dependency reopened "
|
||||
f"({dependency_task.title})."
|
||||
),
|
||||
agent_id=actor_agent_id,
|
||||
)
|
||||
else:
|
||||
@@ -230,7 +258,9 @@ async def _fetch_task_events(
|
||||
board_id: UUID,
|
||||
since: datetime,
|
||||
) -> list[tuple[ActivityEvent, Task | None]]:
|
||||
task_ids = list(await session.exec(select(Task.id).where(col(Task.board_id) == board_id)))
|
||||
task_ids = list(
|
||||
await session.exec(select(Task.id).where(col(Task.board_id) == board_id)),
|
||||
)
|
||||
if not task_ids:
|
||||
return []
|
||||
statement = cast(
|
||||
@@ -249,7 +279,9 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
|
||||
return TaskCommentRead.model_validate(event).model_dump(mode="json")
|
||||
|
||||
|
||||
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
|
||||
async def _gateway_config(
|
||||
session: AsyncSession, board: Board,
|
||||
) -> GatewayClientConfig | None:
|
||||
if not board.gateway_id:
|
||||
return None
|
||||
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
|
||||
@@ -303,7 +335,10 @@ async def _notify_agent_on_task_assign(
|
||||
message = (
|
||||
"TASK ASSIGNED\n"
|
||||
+ "\n".join(details)
|
||||
+ "\n\nTake action: open the task and begin work. Post updates as task comments."
|
||||
+ (
|
||||
"\n\nTake action: open the task and begin work. "
|
||||
"Post updates as task comments."
|
||||
)
|
||||
)
|
||||
try:
|
||||
await _send_agent_task_message(
|
||||
@@ -442,17 +477,18 @@ async def _notify_lead_on_task_unassigned(
|
||||
|
||||
|
||||
@router.get("/stream")
|
||||
async def stream_tasks(
|
||||
async def stream_tasks( # noqa: C901
|
||||
request: Request,
|
||||
board: Board = Depends(get_board_for_actor_read),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
since: str | None = Query(default=None),
|
||||
board: Board = BOARD_READ_DEP,
|
||||
_actor: ActorContext = ACTOR_DEP,
|
||||
since: str | None = SINCE_QUERY,
|
||||
) -> EventSourceResponse:
|
||||
"""Stream task and task-comment events as SSE payloads."""
|
||||
since_dt = _parse_since(since) or utcnow()
|
||||
seen_ids: set[UUID] = set()
|
||||
seen_queue: deque[UUID] = deque()
|
||||
|
||||
async def event_generator() -> AsyncIterator[dict[str, str]]:
|
||||
async def event_generator() -> AsyncIterator[dict[str, str]]: # noqa: C901
|
||||
last_seen = since_dt
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
@@ -510,7 +546,7 @@ async def stream_tasks(
|
||||
"depends_on_task_ids": dep_list,
|
||||
"blocked_by_task_ids": blocked_by,
|
||||
"is_blocked": bool(blocked_by),
|
||||
}
|
||||
},
|
||||
)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
@@ -521,14 +557,15 @@ async def stream_tasks(
|
||||
|
||||
|
||||
@router.get("", response_model=DefaultLimitOffsetPage[TaskRead])
|
||||
async def list_tasks(
|
||||
status_filter: str | None = Query(default=None, alias="status"),
|
||||
async def list_tasks( # noqa: C901
|
||||
status_filter: str | None = STATUS_QUERY,
|
||||
assigned_agent_id: UUID | None = None,
|
||||
unassigned: bool | None = None,
|
||||
board: Board = Depends(get_board_for_actor_read),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
board: Board = BOARD_READ_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
_actor: ActorContext = ACTOR_DEP,
|
||||
) -> DefaultLimitOffsetPage[TaskRead]:
|
||||
"""List board tasks with optional status and assignment filters."""
|
||||
statement = select(Task).where(Task.board_id == board.id)
|
||||
if status_filter:
|
||||
statuses = [s.strip() for s in status_filter.split(",") if s.strip()]
|
||||
@@ -550,7 +587,9 @@ async def list_tasks(
|
||||
if not tasks:
|
||||
return []
|
||||
task_ids = [task.id for task in tasks]
|
||||
deps_map = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids)
|
||||
deps_map = await dependency_ids_by_task_id(
|
||||
session, board_id=board.id, task_ids=task_ids,
|
||||
)
|
||||
dep_ids: list[UUID] = []
|
||||
for value in deps_map.values():
|
||||
dep_ids.extend(value)
|
||||
@@ -563,7 +602,9 @@ async def list_tasks(
|
||||
output: list[TaskRead] = []
|
||||
for task in tasks:
|
||||
dep_list = deps_map.get(task.id, [])
|
||||
blocked_by = blocked_by_dependency_ids(dependency_ids=dep_list, status_by_id=dep_status)
|
||||
blocked_by = blocked_by_dependency_ids(
|
||||
dependency_ids=dep_list, status_by_id=dep_status,
|
||||
)
|
||||
if task.status == "done":
|
||||
blocked_by = []
|
||||
output.append(
|
||||
@@ -572,8 +613,8 @@ async def list_tasks(
|
||||
"depends_on_task_ids": dep_list,
|
||||
"blocked_by_task_ids": blocked_by,
|
||||
"is_blocked": bool(blocked_by),
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -583,10 +624,11 @@ async def list_tasks(
|
||||
@router.post("", response_model=TaskRead, responses={409: {"model": BlockedTaskError}})
|
||||
async def create_task(
|
||||
payload: TaskCreate,
|
||||
board: Board = Depends(get_board_for_user_write),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(require_admin_auth),
|
||||
board: Board = BOARD_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = ADMIN_AUTH_DEP,
|
||||
) -> TaskRead:
|
||||
"""Create a task and initialize dependency rows."""
|
||||
data = payload.model_dump()
|
||||
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
|
||||
|
||||
@@ -606,7 +648,9 @@ async def create_task(
|
||||
board_id=board.id,
|
||||
dependency_ids=normalized_deps,
|
||||
)
|
||||
blocked_by = blocked_by_dependency_ids(dependency_ids=normalized_deps, status_by_id=dep_status)
|
||||
blocked_by = blocked_by_dependency_ids(
|
||||
dependency_ids=normalized_deps, status_by_id=dep_status,
|
||||
)
|
||||
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
|
||||
raise _blocked_task_error(blocked_by)
|
||||
session.add(task)
|
||||
@@ -618,7 +662,7 @@ async def create_task(
|
||||
board_id=board.id,
|
||||
task_id=task.id,
|
||||
depends_on_task_id=dep_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(task)
|
||||
@@ -632,7 +676,9 @@ async def create_task(
|
||||
await session.commit()
|
||||
await _notify_lead_on_task_create(session=session, board=board, task=task)
|
||||
if task.assigned_agent_id:
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
|
||||
session,
|
||||
)
|
||||
if assigned_agent:
|
||||
await _notify_agent_on_task_assign(
|
||||
session=session,
|
||||
@@ -645,7 +691,7 @@ async def create_task(
|
||||
"depends_on_task_ids": normalized_deps,
|
||||
"blocked_by_task_ids": blocked_by,
|
||||
"is_blocked": bool(blocked_by),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -654,12 +700,13 @@ async def create_task(
|
||||
response_model=TaskRead,
|
||||
responses={409: {"model": BlockedTaskError}},
|
||||
)
|
||||
async def update_task(
|
||||
async def update_task( # noqa: C901, PLR0912, PLR0915
|
||||
payload: TaskUpdate,
|
||||
task: Task = Depends(get_task_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
task: Task = TASK_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
actor: ActorContext = ACTOR_DEP,
|
||||
) -> TaskRead:
|
||||
"""Update task status, assignment, comment, and dependency state."""
|
||||
if task.board_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
@@ -676,7 +723,9 @@ async def update_task(
|
||||
previous_assigned = task.assigned_agent_id
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
comment = updates.pop("comment", None)
|
||||
depends_on_task_ids = cast(list[UUID] | None, updates.pop("depends_on_task_ids", None))
|
||||
depends_on_task_ids = cast(
|
||||
list[UUID] | None, updates.pop("depends_on_task_ids", None),
|
||||
)
|
||||
|
||||
requested_fields = set(updates)
|
||||
if comment is not None:
|
||||
@@ -685,7 +734,9 @@ async def update_task(
|
||||
requested_fields.add("depends_on_task_ids")
|
||||
|
||||
async def _current_dep_ids() -> list[UUID]:
|
||||
deps_map = await dependency_ids_by_task_id(session, board_id=board_id, task_ids=[task.id])
|
||||
deps_map = await dependency_ids_by_task_id(
|
||||
session, board_id=board_id, task_ids=[task.id],
|
||||
)
|
||||
return deps_map.get(task.id, [])
|
||||
|
||||
async def _blocked_by(dep_ids: Sequence[UUID]) -> list[UUID]:
|
||||
@@ -696,16 +747,20 @@ async def update_task(
|
||||
board_id=board_id,
|
||||
dependency_ids=list(dep_ids),
|
||||
)
|
||||
return blocked_by_dependency_ids(dependency_ids=list(dep_ids), status_by_id=dep_status)
|
||||
return blocked_by_dependency_ids(
|
||||
dependency_ids=list(dep_ids), status_by_id=dep_status,
|
||||
)
|
||||
|
||||
# Lead agent: delegation only (assign/unassign, resolve review, manage dependencies).
|
||||
# Lead agent: delegation only.
|
||||
# Assign/unassign, resolve review, and manage dependencies.
|
||||
if actor.actor_type == "agent" and actor.agent and actor.agent.is_board_lead:
|
||||
allowed_fields = {"assigned_agent_id", "status", "depends_on_task_ids"}
|
||||
if comment is not None or not requested_fields.issubset(allowed_fields):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=(
|
||||
"Board leads can only assign/unassign tasks, update dependencies, or resolve review tasks."
|
||||
"Board leads can only assign/unassign tasks, update "
|
||||
"dependencies, or resolve review tasks."
|
||||
),
|
||||
)
|
||||
|
||||
@@ -745,7 +800,11 @@ async def update_task(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Board leads cannot assign tasks to themselves.",
|
||||
)
|
||||
if agent.board_id and task.board_id and agent.board_id != task.board_id:
|
||||
if (
|
||||
agent.board_id
|
||||
and task.board_id
|
||||
and agent.board_id != task.board_id
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
||||
task.assigned_agent_id = agent.id
|
||||
else:
|
||||
@@ -755,12 +814,18 @@ async def update_task(
|
||||
if task.status != "review":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Board leads can only change status when a task is in review.",
|
||||
detail=(
|
||||
"Board leads can only change status when a task is "
|
||||
"in review."
|
||||
),
|
||||
)
|
||||
if updates["status"] not in {"done", "inbox"}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Board leads can only move review tasks to done or inbox.",
|
||||
detail=(
|
||||
"Board leads can only move review tasks to done "
|
||||
"or inbox."
|
||||
),
|
||||
)
|
||||
if updates["status"] == "inbox":
|
||||
task.assigned_agent_id = None
|
||||
@@ -793,7 +858,9 @@ async def update_task(
|
||||
await session.refresh(task)
|
||||
|
||||
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
|
||||
session,
|
||||
)
|
||||
if assigned_agent:
|
||||
board = (
|
||||
await Board.objects.by_id(task.board_id).first(session)
|
||||
@@ -817,14 +884,18 @@ async def update_task(
|
||||
"depends_on_task_ids": dep_ids,
|
||||
"blocked_by_task_ids": blocked_ids,
|
||||
"is_blocked": bool(blocked_ids),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Non-lead agent: can only change status + comment, and cannot start blocked tasks.
|
||||
if actor.actor_type == "agent":
|
||||
if actor.agent and actor.agent.board_id and task.board_id:
|
||||
if actor.agent.board_id != task.board_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if (
|
||||
actor.agent
|
||||
and actor.agent.board_id
|
||||
and task.board_id
|
||||
and actor.agent.board_id != task.board_id
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
allowed_fields = {"status", "comment"}
|
||||
if depends_on_task_ids is not None or not set(updates).issubset(allowed_fields):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
@@ -858,14 +929,16 @@ async def update_task(
|
||||
)
|
||||
|
||||
effective_deps = (
|
||||
admin_normalized_deps if admin_normalized_deps is not None else await _current_dep_ids()
|
||||
admin_normalized_deps
|
||||
if admin_normalized_deps is not None
|
||||
else await _current_dep_ids()
|
||||
)
|
||||
blocked_ids = await _blocked_by(effective_deps)
|
||||
|
||||
target_status = cast(str, updates.get("status", task.status))
|
||||
if blocked_ids and not (task.status == "done" and target_status == "done"):
|
||||
# Blocked tasks cannot be assigned or moved out of inbox. If the task is already in
|
||||
# flight, force it back to inbox and unassign it.
|
||||
# Blocked tasks cannot be assigned or moved out of inbox.
|
||||
# If the task is already in flight, force it back to inbox and unassign it.
|
||||
task.status = "inbox"
|
||||
task.assigned_agent_id = None
|
||||
task.in_progress_at = None
|
||||
@@ -910,7 +983,9 @@ async def update_task(
|
||||
event_type="task.comment",
|
||||
message=comment,
|
||||
task_id=task.id,
|
||||
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None,
|
||||
agent_id=actor.agent.id
|
||||
if actor.actor_type == "agent" and actor.agent
|
||||
else None,
|
||||
)
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
@@ -921,7 +996,9 @@ async def update_task(
|
||||
else:
|
||||
event_type = "task.updated"
|
||||
message = f"Task updated: {task.title}."
|
||||
actor_agent_id = actor.agent.id if actor.actor_type == "agent" and actor.agent else None
|
||||
actor_agent_id = (
|
||||
actor.agent.id if actor.actor_type == "agent" and actor.agent else None
|
||||
)
|
||||
record_activity(
|
||||
session,
|
||||
event_type=event_type,
|
||||
@@ -938,23 +1015,34 @@ async def update_task(
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if task.status == "inbox" and task.assigned_agent_id is None:
|
||||
if previous_status != "inbox" or previous_assigned is not None:
|
||||
board = (
|
||||
await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
|
||||
if (
|
||||
task.status == "inbox"
|
||||
and task.assigned_agent_id is None
|
||||
and (previous_status != "inbox" or previous_assigned is not None)
|
||||
):
|
||||
board = (
|
||||
await Board.objects.by_id(task.board_id).first(session)
|
||||
if task.board_id
|
||||
else None
|
||||
)
|
||||
if board:
|
||||
await _notify_lead_on_task_unassigned(
|
||||
session=session,
|
||||
board=board,
|
||||
task=task,
|
||||
)
|
||||
if board:
|
||||
await _notify_lead_on_task_unassigned(
|
||||
session=session,
|
||||
board=board,
|
||||
task=task,
|
||||
)
|
||||
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
|
||||
if actor.actor_type == "agent" and actor.agent and task.assigned_agent_id == actor.agent.id:
|
||||
if (
|
||||
actor.actor_type == "agent"
|
||||
and actor.agent
|
||||
and task.assigned_agent_id == actor.agent.id
|
||||
):
|
||||
# Don't notify the actor about their own assignment.
|
||||
pass
|
||||
else:
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
|
||||
session,
|
||||
)
|
||||
if assigned_agent:
|
||||
board = (
|
||||
await Board.objects.by_id(task.board_id).first(session)
|
||||
@@ -978,16 +1066,17 @@ async def update_task(
|
||||
"depends_on_task_ids": dep_ids,
|
||||
"blocked_by_task_ids": blocked_ids,
|
||||
"is_blocked": bool(blocked_ids),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=OkResponse)
|
||||
async def delete_task(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
task: Task = Depends(get_task_or_404),
|
||||
auth: AuthContext = Depends(require_admin_auth),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
task: Task = TASK_DEP,
|
||||
auth: AuthContext = ADMIN_AUTH_DEP,
|
||||
) -> OkResponse:
|
||||
"""Delete a task and related records."""
|
||||
if task.board_id is None:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
board = await Board.objects.by_id(task.board_id).first(session)
|
||||
@@ -997,12 +1086,14 @@ async def delete_task(
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
await require_board_access(session, user=auth.user, board=board, write=True)
|
||||
await crud.delete_where(
|
||||
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False
|
||||
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False
|
||||
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session, Approval, col(Approval.task_id) == task.id, commit=False,
|
||||
)
|
||||
await crud.delete_where(session, Approval, col(Approval.task_id) == task.id, commit=False)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
TaskDependency,
|
||||
@@ -1017,11 +1108,14 @@ async def delete_task(
|
||||
return OkResponse()
|
||||
|
||||
|
||||
@router.get("/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead])
|
||||
@router.get(
|
||||
"/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead],
|
||||
)
|
||||
async def list_task_comments(
|
||||
task: Task = Depends(get_task_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
task: Task = TASK_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> DefaultLimitOffsetPage[TaskCommentRead]:
|
||||
"""List comments for a task in chronological order."""
|
||||
statement = (
|
||||
select(ActivityEvent)
|
||||
.where(col(ActivityEvent.task_id) == task.id)
|
||||
@@ -1032,12 +1126,13 @@ async def list_task_comments(
|
||||
|
||||
|
||||
@router.post("/{task_id}/comments", response_model=TaskCommentRead)
|
||||
async def create_task_comment(
|
||||
async def create_task_comment( # noqa: C901, PLR0912
|
||||
payload: TaskCommentCreate,
|
||||
task: Task = Depends(get_task_or_404),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
actor: ActorContext = Depends(require_admin_or_agent),
|
||||
task: Task = TASK_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
actor: ActorContext = ACTOR_DEP,
|
||||
) -> ActivityEvent:
|
||||
"""Create a task comment and notify relevant agents."""
|
||||
if task.board_id is None:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
if actor.actor_type == "user" and actor.user is not None:
|
||||
@@ -1045,22 +1140,28 @@ async def create_task_comment(
|
||||
if board is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
await require_board_access(session, user=actor.user, board=board, write=True)
|
||||
if actor.actor_type == "agent" and actor.agent:
|
||||
if actor.agent.is_board_lead and task.status != "review":
|
||||
if not await _lead_was_mentioned(session, task, actor.agent) and not _lead_created_task(
|
||||
task, actor.agent
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=(
|
||||
"Board leads can only comment during review, when mentioned, or on tasks they created."
|
||||
),
|
||||
)
|
||||
if (
|
||||
actor.actor_type == "agent"
|
||||
and actor.agent
|
||||
and actor.agent.is_board_lead
|
||||
and task.status != "review"
|
||||
and not await _lead_was_mentioned(session, task, actor.agent)
|
||||
and not _lead_created_task(task, actor.agent)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=(
|
||||
"Board leads can only comment during review, when mentioned, "
|
||||
"or on tasks they created."
|
||||
),
|
||||
)
|
||||
event = ActivityEvent(
|
||||
event_type="task.comment",
|
||||
message=payload.message,
|
||||
task_id=task.id,
|
||||
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None,
|
||||
agent_id=actor.agent.id
|
||||
if actor.actor_type == "agent" and actor.agent
|
||||
else None,
|
||||
)
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
@@ -1072,17 +1173,27 @@ async def create_task_comment(
|
||||
if matches_agent_mention(agent, mention_names):
|
||||
targets[agent.id] = agent
|
||||
if not mention_names and task.assigned_agent_id:
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
|
||||
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
|
||||
session,
|
||||
)
|
||||
if assigned_agent:
|
||||
targets[assigned_agent.id] = assigned_agent
|
||||
if actor.actor_type == "agent" and actor.agent:
|
||||
targets.pop(actor.agent.id, None)
|
||||
if targets:
|
||||
board = await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
|
||||
board = (
|
||||
await Board.objects.by_id(task.board_id).first(session)
|
||||
if task.board_id
|
||||
else None
|
||||
)
|
||||
config = await _gateway_config(session, board) if board else None
|
||||
if board and config:
|
||||
snippet = _truncate_snippet(payload.message)
|
||||
actor_name = actor.agent.name if actor.actor_type == "agent" and actor.agent else "User"
|
||||
actor_name = (
|
||||
actor.agent.name
|
||||
if actor.actor_type == "agent" and actor.agent
|
||||
else "User"
|
||||
)
|
||||
for agent in targets.values():
|
||||
if not agent.openclaw_session_id:
|
||||
continue
|
||||
@@ -1101,15 +1212,14 @@ async def create_task_comment(
|
||||
f"From: {actor_name}\n\n"
|
||||
f"{action_line}\n\n"
|
||||
f"Comment:\n{snippet}\n\n"
|
||||
"If you are mentioned but not assigned, reply in the task thread but do not change task status."
|
||||
"If you are mentioned but not assigned, reply in the task "
|
||||
"thread but do not change task status."
|
||||
)
|
||||
try:
|
||||
with suppress(OpenClawGatewayError):
|
||||
await _send_agent_task_message(
|
||||
session_key=agent.openclaw_session_id,
|
||||
config=config,
|
||||
agent_name=agent.name,
|
||||
message=message,
|
||||
)
|
||||
except OpenClawGatewayError:
|
||||
pass
|
||||
return event
|
||||
|
||||
@@ -1,18 +1,28 @@
|
||||
"""User self-service API endpoints for profile retrieval and updates."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.db.session import get_session
|
||||
from app.models.users import User
|
||||
from app.schemas.users import UserRead, UserUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.users import User
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
AUTH_CONTEXT_DEP = Depends(get_auth_context)
|
||||
SESSION_DEP = Depends(get_session)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserRead)
|
||||
async def get_me(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
|
||||
async def get_me(auth: AuthContext = AUTH_CONTEXT_DEP) -> UserRead:
|
||||
"""Return the authenticated user's current profile payload."""
|
||||
if auth.actor_type != "user" or auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return UserRead.model_validate(auth.user)
|
||||
@@ -21,9 +31,10 @@ async def get_me(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
|
||||
@router.patch("/me", response_model=UserRead)
|
||||
async def update_me(
|
||||
payload: UserUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_CONTEXT_DEP,
|
||||
) -> UserRead:
|
||||
"""Apply partial profile updates for the authenticated user."""
|
||||
if auth.actor_type != "user" or auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Core utilities and configuration for the backend service."""
|
||||
|
||||
@@ -1,33 +1,44 @@
|
||||
"""Agent authentication helpers for token-backed API access."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, Request, status
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.agent_tokens import verify_agent_token
|
||||
from app.core.time import utcnow
|
||||
from app.db.session import get_session
|
||||
from app.models.agents import Agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_LAST_SEEN_TOUCH_INTERVAL = timedelta(seconds=30)
|
||||
_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
|
||||
SESSION_DEP = Depends(get_session)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAuthContext:
|
||||
"""Authenticated actor payload for agent-originated requests."""
|
||||
|
||||
actor_type: Literal["agent"]
|
||||
agent: Agent
|
||||
|
||||
|
||||
async def _find_agent_for_token(session: AsyncSession, token: str) -> Agent | None:
|
||||
agents = list(await session.exec(select(Agent).where(col(Agent.agent_token_hash).is_not(None))))
|
||||
agents = list(
|
||||
await session.exec(
|
||||
select(Agent).where(col(Agent.agent_token_hash).is_not(None)),
|
||||
),
|
||||
)
|
||||
for agent in agents:
|
||||
if agent.agent_token_hash and verify_agent_token(token, agent.agent_token_hash):
|
||||
return agent
|
||||
@@ -65,9 +76,11 @@ async def _touch_agent_presence(
|
||||
calls (task comments, memory updates, etc). Touch presence so the UI reflects
|
||||
real activity even if the heartbeat loop isn't running.
|
||||
"""
|
||||
|
||||
now = utcnow()
|
||||
if agent.last_seen_at is not None and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL:
|
||||
if (
|
||||
agent.last_seen_at is not None
|
||||
and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL
|
||||
):
|
||||
return
|
||||
|
||||
agent.last_seen_at = now
|
||||
@@ -86,9 +99,14 @@ async def get_agent_auth_context(
|
||||
request: Request,
|
||||
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> AgentAuthContext:
|
||||
resolved = _resolve_agent_token(agent_token, authorization, accept_authorization=True)
|
||||
"""Require and validate agent auth token from request headers."""
|
||||
resolved = _resolve_agent_token(
|
||||
agent_token,
|
||||
authorization,
|
||||
accept_authorization=True,
|
||||
)
|
||||
if not resolved:
|
||||
logger.warning(
|
||||
"agent auth missing token path=%s x_agent=%s authorization=%s",
|
||||
@@ -113,8 +131,9 @@ async def get_agent_auth_context_optional(
|
||||
request: Request,
|
||||
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> AgentAuthContext | None:
|
||||
"""Optionally resolve agent auth context from `X-Agent-Token` only."""
|
||||
resolved = _resolve_agent_token(
|
||||
agent_token,
|
||||
authorization,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Token generation and verification helpers for agent authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
@@ -10,6 +12,7 @@ SALT_BYTES = 16
|
||||
|
||||
|
||||
def generate_agent_token() -> str:
|
||||
"""Generate a new URL-safe random token for an agent."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@@ -23,12 +26,14 @@ def _b64decode(value: str) -> bytes:
|
||||
|
||||
|
||||
def hash_agent_token(token: str) -> str:
|
||||
"""Hash an agent token using PBKDF2-HMAC-SHA256 with a random salt."""
|
||||
salt = secrets.token_bytes(SALT_BYTES)
|
||||
digest = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, ITERATIONS)
|
||||
return f"pbkdf2_sha256${ITERATIONS}${_b64encode(salt)}${_b64encode(digest)}"
|
||||
|
||||
|
||||
def verify_agent_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify a plaintext token against a stored PBKDF2 hash representation."""
|
||||
try:
|
||||
algorithm, iterations, salt_b64, digest_b64 = stored_hash.split("$")
|
||||
except ValueError:
|
||||
@@ -41,5 +46,10 @@ def verify_agent_token(token: str, stored_hash: str) -> bool:
|
||||
return False
|
||||
salt = _b64decode(salt_b64)
|
||||
expected_digest = _b64decode(digest_b64)
|
||||
candidate = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, iterations_int)
|
||||
candidate = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
token.encode("utf-8"),
|
||||
salt,
|
||||
iterations_int,
|
||||
)
|
||||
return hmac.compare_digest(candidate, expected_digest)
|
||||
|
||||
@@ -1,32 +1,42 @@
|
||||
"""User authentication helpers backed by Clerk JWT verification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer
|
||||
from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db import crud
|
||||
from app.db.session import get_session
|
||||
from app.models.users import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
SECURITY_DEP = Depends(security)
|
||||
SESSION_DEP = Depends(get_session)
|
||||
CLERK_JWKS_URL_REQUIRED_ERROR = "CLERK_JWKS_URL is not set."
|
||||
|
||||
|
||||
class ClerkTokenPayload(BaseModel):
|
||||
"""JWT claims payload shape required from Clerk tokens."""
|
||||
|
||||
sub: str
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer:
|
||||
def _build_clerk_http_bearer(*, auto_error: bool) -> ClerkHTTPBearer:
|
||||
"""Create and cache the Clerk HTTP bearer guard."""
|
||||
if not settings.clerk_jwks_url:
|
||||
raise RuntimeError("CLERK_JWKS_URL is not set.")
|
||||
raise RuntimeError(CLERK_JWKS_URL_REQUIRED_ERROR)
|
||||
clerk_config = ClerkConfig(
|
||||
jwks_url=settings.clerk_jwks_url,
|
||||
verify_iat=settings.clerk_verify_iat,
|
||||
@@ -37,12 +47,15 @@ def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer:
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
"""Authenticated user context resolved from inbound auth headers."""
|
||||
|
||||
actor_type: Literal["user"]
|
||||
user: User | None = None
|
||||
|
||||
|
||||
def _resolve_clerk_auth(
|
||||
request: Request, fallback: ClerkCredentials | None
|
||||
request: Request,
|
||||
fallback: ClerkCredentials | None,
|
||||
) -> ClerkCredentials | None:
|
||||
auth_data = getattr(request.state, "clerk_auth", None)
|
||||
if isinstance(auth_data, ClerkCredentials):
|
||||
@@ -59,9 +72,10 @@ def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> AuthContext:
|
||||
"""Resolve required authenticated user context from Clerk JWT headers."""
|
||||
if credentials is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
@@ -109,9 +123,10 @@ async def get_auth_context(
|
||||
|
||||
async def get_auth_context_optional(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> AuthContext | None:
|
||||
"""Resolve user context if available, otherwise return `None`."""
|
||||
if request.headers.get("X-Agent-Token"):
|
||||
return None
|
||||
if credentials is None:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Application settings and environment configuration loading."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
@@ -11,6 +13,8 @@ DEFAULT_ENV_FILE = BACKEND_ROOT / ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Typed runtime configuration sourced from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
# Load `backend/.env` regardless of current working directory.
|
||||
# (Important when running uvicorn from repo root or via a process manager.)
|
||||
@@ -32,8 +36,8 @@ class Settings(BaseSettings):
|
||||
base_url: str = ""
|
||||
|
||||
# Optional: local directory where the backend is allowed to write "preserved" agent
|
||||
# workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local writes are disabled
|
||||
# and provisioning relies on the gateway API.
|
||||
# workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local
|
||||
# writes are disabled and provisioning relies on the gateway API.
|
||||
#
|
||||
# Security note: do NOT point this at arbitrary system paths in production.
|
||||
local_agent_workspace_root: str = ""
|
||||
@@ -48,8 +52,8 @@ class Settings(BaseSettings):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _defaults(self) -> Self:
|
||||
# In dev, default to applying Alembic migrations at startup to avoid schema drift
|
||||
# (e.g. missing newly-added columns).
|
||||
# In dev, default to applying Alembic migrations at startup to avoid
|
||||
# schema drift (e.g. missing newly-added columns).
|
||||
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":
|
||||
self.db_auto_migrate = True
|
||||
return self
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Utilities for parsing human-readable duration schedule strings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
_DURATION_RE = re.compile(r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$", flags=re.IGNORECASE)
|
||||
_DURATION_RE = re.compile(
|
||||
r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
_MULTIPLIERS: dict[str, int] = {
|
||||
"s": 1,
|
||||
@@ -11,26 +16,36 @@ _MULTIPLIERS: dict[str, int] = {
|
||||
"d": 60 * 60 * 24,
|
||||
"w": 60 * 60 * 24 * 7,
|
||||
}
|
||||
_MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10
|
||||
|
||||
_ERR_SCHEDULE_REQUIRED = "schedule is required"
|
||||
_ERR_SCHEDULE_INVALID = (
|
||||
'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
|
||||
)
|
||||
_ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0."
|
||||
_ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)."
|
||||
|
||||
|
||||
def normalize_every(value: str) -> str:
|
||||
"""Normalize schedule string to lower-case compact unit form."""
|
||||
normalized = value.strip().lower().replace(" ", "")
|
||||
if not normalized:
|
||||
raise ValueError("schedule is required")
|
||||
raise ValueError(_ERR_SCHEDULE_REQUIRED)
|
||||
return normalized
|
||||
|
||||
|
||||
def parse_every_to_seconds(value: str) -> int:
|
||||
"""Parse compact schedule syntax into a number of seconds."""
|
||||
normalized = normalize_every(value)
|
||||
match = _DURATION_RE.match(normalized)
|
||||
if not match:
|
||||
raise ValueError('Invalid schedule. Expected format like "10m", "1h", "2d", "1w".')
|
||||
raise ValueError(_ERR_SCHEDULE_INVALID)
|
||||
num = int(match.group("num"))
|
||||
unit = match.group("unit").lower()
|
||||
seconds = num * _MULTIPLIERS[unit]
|
||||
if seconds <= 0:
|
||||
raise ValueError("Schedule must be greater than 0.")
|
||||
raise ValueError(_ERR_SCHEDULE_NONPOSITIVE)
|
||||
# Prevent accidental absurd schedules (e.g. 999999999d).
|
||||
if seconds > 60 * 60 * 24 * 365 * 10:
|
||||
raise ValueError("Schedule is too large (max 10 years).")
|
||||
if seconds > _MAX_SCHEDULE_SECONDS:
|
||||
raise ValueError(_ERR_SCHEDULE_TOO_LARGE)
|
||||
return seconds
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Global exception handlers and request-id middleware for FastAPI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Final, cast
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
@@ -10,7 +12,9 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,12 +24,16 @@ ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]
|
||||
|
||||
|
||||
class RequestIdMiddleware:
|
||||
"""ASGI middleware that ensures every request has a request-id."""
|
||||
|
||||
def __init__(self, app: ASGIApp, *, header_name: str = REQUEST_ID_HEADER) -> None:
|
||||
"""Initialize middleware with app instance and header name."""
|
||||
self._app = app
|
||||
self._header_name = header_name
|
||||
self._header_name_bytes = header_name.lower().encode("latin-1")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Inject request-id into request state and response headers."""
|
||||
if scope["type"] != "http":
|
||||
await self._app(scope, receive, send)
|
||||
return
|
||||
@@ -36,8 +44,11 @@ class RequestIdMiddleware:
|
||||
if message["type"] == "http.response.start":
|
||||
# Starlette uses `list[tuple[bytes, bytes]]` here.
|
||||
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
|
||||
if not any(key.lower() == self._header_name_bytes for key, _ in headers):
|
||||
headers.append((self._header_name_bytes, request_id.encode("latin-1")))
|
||||
if not any(
|
||||
key.lower() == self._header_name_bytes for key, _ in headers
|
||||
):
|
||||
request_id_bytes = request_id.encode("latin-1")
|
||||
headers.append((self._header_name_bytes, request_id_bytes))
|
||||
await send(message)
|
||||
|
||||
await self._app(scope, receive, send_with_request_id)
|
||||
@@ -62,8 +73,10 @@ class RequestIdMiddleware:
|
||||
|
||||
|
||||
def install_error_handling(app: FastAPI) -> None:
|
||||
"""Install middleware and exception handlers on the FastAPI app."""
|
||||
# Important: add request-id middleware last so it's the outermost middleware.
|
||||
# This ensures it still runs even if another middleware (e.g. CORS preflight) returns early.
|
||||
# This ensures it still runs even if another middleware
|
||||
# (e.g. CORS preflight) returns early.
|
||||
app.add_middleware(RequestIdMiddleware)
|
||||
|
||||
app.add_exception_handler(
|
||||
@@ -88,7 +101,7 @@ def _get_request_id(request: Request) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
|
||||
def _error_payload(*, detail: object, request_id: str | None) -> dict[str, object]:
|
||||
payload: dict[str, Any] = {"detail": detail}
|
||||
if request_id:
|
||||
payload["request_id"] = request_id
|
||||
@@ -96,7 +109,8 @@ def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _request_validation_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
request: Request,
|
||||
exc: RequestValidationError,
|
||||
) -> JSONResponse:
|
||||
# `RequestValidationError` is expected user input; don't log at ERROR.
|
||||
request_id = _get_request_id(request)
|
||||
@@ -107,7 +121,8 @@ async def _request_validation_handler(
|
||||
|
||||
|
||||
async def _response_validation_handler(
|
||||
request: Request, exc: ResponseValidationError
|
||||
request: Request,
|
||||
exc: ResponseValidationError,
|
||||
) -> JSONResponse:
|
||||
request_id = _get_request_id(request)
|
||||
logger.exception(
|
||||
@@ -125,7 +140,10 @@ async def _response_validation_handler(
|
||||
)
|
||||
|
||||
|
||||
async def _http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
async def _http_exception_handler(
|
||||
request: Request,
|
||||
exc: StarletteHTTPException,
|
||||
) -> JSONResponse:
|
||||
request_id = _get_request_id(request)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
@@ -134,11 +152,18 @@ async def _http_exception_handler(request: Request, exc: StarletteHTTPException)
|
||||
)
|
||||
|
||||
|
||||
async def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
async def _unhandled_exception_handler(
|
||||
request: Request,
|
||||
_exc: Exception,
|
||||
) -> JSONResponse:
|
||||
request_id = _get_request_id(request)
|
||||
logger.exception(
|
||||
"unhandled_exception",
|
||||
extra={"request_id": request_id, "method": request.method, "path": request.url.path},
|
||||
extra={
|
||||
"request_id": request_id,
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Application logging configuration and formatter utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
@@ -15,7 +17,8 @@ TRACE_LEVEL = 5
|
||||
logging.addLevelName(TRACE_LEVEL, "TRACE")
|
||||
|
||||
|
||||
def _trace(self: logging.Logger, message: str, *args: Any, **kwargs: Any) -> None:
|
||||
def _trace(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None:
|
||||
"""Log a TRACE-level message when the logger is TRACE-enabled."""
|
||||
if self.isEnabledFor(TRACE_LEVEL):
|
||||
self._log(TRACE_LEVEL, message, args, **kwargs)
|
||||
|
||||
@@ -52,21 +55,31 @@ _STANDARD_LOG_RECORD_ATTRS = {
|
||||
|
||||
|
||||
class AppLogFilter(logging.Filter):
|
||||
"""Inject app metadata into each log record."""
|
||||
|
||||
def __init__(self, app_name: str, version: str) -> None:
|
||||
"""Initialize the filter with fixed app and version values."""
|
||||
super().__init__()
|
||||
self._app_name = app_name
|
||||
self._version = version
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""Attach app metadata fields to each emitted record."""
|
||||
record.app = self._app_name
|
||||
record.version = self._version
|
||||
return True
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
"""Formatter that serializes log records as compact JSON."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Render a single log record into a JSON string."""
|
||||
payload: dict[str, Any] = {
|
||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||
"timestamp": datetime.fromtimestamp(
|
||||
record.created,
|
||||
tz=timezone.utc,
|
||||
).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
@@ -88,7 +101,10 @@ class JsonFormatter(logging.Formatter):
|
||||
|
||||
|
||||
class KeyValueFormatter(logging.Formatter):
|
||||
"""Formatter that appends extra fields as `key=value` pairs."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Render a log line with appended non-standard record fields."""
|
||||
base = super().format(record)
|
||||
extras = {
|
||||
key: value
|
||||
@@ -102,6 +118,8 @@ class KeyValueFormatter(logging.Formatter):
|
||||
|
||||
|
||||
class AppLogger:
|
||||
"""Centralized logging setup utility for the backend process."""
|
||||
|
||||
_configured = False
|
||||
|
||||
@classmethod
|
||||
@@ -111,10 +129,12 @@ class AppLogger:
|
||||
return level_name, TRACE_LEVEL
|
||||
if level_name.isdigit():
|
||||
return level_name, int(level_name)
|
||||
return level_name, logging._nameToLevel.get(level_name, logging.INFO)
|
||||
levels = logging.getLevelNamesMapping()
|
||||
return level_name, levels.get(level_name, logging.INFO)
|
||||
|
||||
@classmethod
|
||||
def configure(cls, *, force: bool = False) -> None:
|
||||
"""Configure root logging handlers, formatters, and library levels."""
|
||||
if cls._configured and not force:
|
||||
return
|
||||
|
||||
@@ -127,7 +147,8 @@ class AppLogger:
|
||||
formatter: logging.Formatter = JsonFormatter()
|
||||
else:
|
||||
formatter = KeyValueFormatter(
|
||||
"%(asctime)s %(levelname)s %(name)s %(message)s app=%(app)s version=%(version)s"
|
||||
"%(asctime)s %(levelname)s %(name)s %(message)s "
|
||||
"app=%(app)s version=%(version)s",
|
||||
)
|
||||
if settings.log_use_utc:
|
||||
formatter.converter = time.gmtime
|
||||
@@ -160,10 +181,12 @@ class AppLogger:
|
||||
|
||||
@classmethod
|
||||
def get_logger(cls, name: str | None = None) -> logging.Logger:
|
||||
"""Return a logger, ensuring logging has been configured."""
|
||||
if not cls._configured:
|
||||
cls.configure()
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def configure_logging() -> None:
|
||||
"""Configure global application logging once during startup."""
|
||||
AppLogger.configure()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Time-related helpers shared across backend modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
@@ -5,6 +7,5 @@ from datetime import UTC, datetime
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Return a naive UTC datetime without using deprecated datetime.utcnow()."""
|
||||
|
||||
# Keep naive UTC values for compatibility with existing DB schema/queries.
|
||||
return datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
"""Application name and version constants."""
|
||||
|
||||
APP_NAME = "mission-control"
|
||||
APP_VERSION = "0.1.0"
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Database helpers and abstractions for backend persistence."""
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""Typed wrapper around fastapi-pagination for backend query helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]]
|
||||
@@ -20,8 +24,10 @@ async def paginate(
|
||||
*,
|
||||
transformer: Transformer | None = None,
|
||||
) -> DefaultLimitOffsetPage[T]:
|
||||
# fastapi-pagination is not fully typed (it returns Any), but response_model validation
|
||||
# ensures runtime correctness. Centralize casts here to keep strict mypy clean.
|
||||
"""Execute a paginated query and cast to the project page type alias."""
|
||||
# fastapi-pagination is not fully typed (it returns Any), but response_model
|
||||
# validation ensures runtime correctness. Centralize casts here to keep strict
|
||||
# mypy clean.
|
||||
return cast(
|
||||
DefaultLimitOffsetPage[T],
|
||||
await _paginate(session, statement, transformer=transformer),
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Model manager descriptor utilities for query-set style access."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import SQLModel, col
|
||||
@@ -13,41 +15,55 @@ ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelManager(Generic[ModelT]):
|
||||
"""Convenience query manager bound to a SQLModel class."""
|
||||
|
||||
model: type[ModelT]
|
||||
id_field: str = "id"
|
||||
|
||||
def all(self) -> QuerySet[ModelT]:
|
||||
"""Return an unfiltered queryset for the bound model."""
|
||||
return qs(self.model)
|
||||
|
||||
def none(self) -> QuerySet[ModelT]:
|
||||
"""Return a queryset that yields no rows."""
|
||||
return qs(self.model).filter(false())
|
||||
|
||||
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by SQL criteria expressions."""
|
||||
return self.all().filter(*criteria)
|
||||
|
||||
def where(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter`."""
|
||||
return self.filter(*criteria)
|
||||
|
||||
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
|
||||
def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by model field equality values."""
|
||||
queryset = self.all()
|
||||
for field_name, value in kwargs.items():
|
||||
queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
|
||||
return queryset
|
||||
|
||||
def by_id(self, obj_id: Any) -> QuerySet[ModelT]:
|
||||
def by_id(self, obj_id: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by primary identifier field."""
|
||||
return self.by_field(self.id_field, obj_id)
|
||||
|
||||
def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]:
|
||||
def by_ids(
|
||||
self,
|
||||
obj_ids: list[object] | tuple[object, ...] | set[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by a set/list/tuple of identifiers."""
|
||||
return self.by_field_in(self.id_field, obj_ids)
|
||||
|
||||
def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]:
|
||||
def by_field(self, field_name: str, value: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by a single field equality check."""
|
||||
return self.filter(col(getattr(self.model, field_name)) == value)
|
||||
|
||||
def by_field_in(
|
||||
self,
|
||||
field_name: str,
|
||||
values: list[Any] | tuple[Any, ...] | set[Any],
|
||||
values: list[object] | tuple[object, ...] | set[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by `field IN values` semantics."""
|
||||
seq = tuple(values)
|
||||
if not seq:
|
||||
return self.none()
|
||||
@@ -55,5 +71,8 @@ class ModelManager(Generic[ModelT]):
|
||||
|
||||
|
||||
class ManagerDescriptor(Generic[ModelT]):
|
||||
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
|
||||
|
||||
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
|
||||
"""Return a fresh manager bound to the owning model class."""
|
||||
return ModelManager(owner)
|
||||
|
||||
@@ -1,50 +1,67 @@
|
||||
"""Lightweight immutable query-set wrapper for SQLModel statements."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
ModelT = TypeVar("ModelT")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QuerySet(Generic[ModelT]):
|
||||
"""Composable immutable wrapper around a SQLModel scalar select statement."""
|
||||
|
||||
statement: SelectOfScalar[ModelT]
|
||||
|
||||
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with additional SQL criteria."""
|
||||
return replace(self, statement=self.statement.where(*criteria))
|
||||
|
||||
def where(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter` to mirror SQLAlchemy naming."""
|
||||
return self.filter(*criteria)
|
||||
|
||||
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
|
||||
def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset filtered by keyword-equality criteria."""
|
||||
statement = self.statement.filter_by(**kwargs)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
|
||||
def order_by(self, *ordering: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with ordering clauses applied."""
|
||||
return replace(self, statement=self.statement.order_by(*ordering))
|
||||
|
||||
def limit(self, value: int) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with a SQL row limit."""
|
||||
return replace(self, statement=self.statement.limit(value))
|
||||
|
||||
def offset(self, value: int) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with a SQL row offset."""
|
||||
return replace(self, statement=self.statement.offset(value))
|
||||
|
||||
async def all(self, session: AsyncSession) -> list[ModelT]:
|
||||
"""Execute and return all rows for the current queryset."""
|
||||
return list(await session.exec(self.statement))
|
||||
|
||||
async def first(self, session: AsyncSession) -> ModelT | None:
|
||||
"""Execute and return the first row, if available."""
|
||||
return (await session.exec(self.statement)).first()
|
||||
|
||||
async def one_or_none(self, session: AsyncSession) -> ModelT | None:
|
||||
"""Execute and return one row or `None`."""
|
||||
return (await session.exec(self.statement)).one_or_none()
|
||||
|
||||
async def exists(self, session: AsyncSession) -> bool:
|
||||
"""Return whether the queryset yields at least one row."""
|
||||
return await self.limit(1).first(session) is not None
|
||||
|
||||
|
||||
def qs(model: type[ModelT]) -> QuerySet[ModelT]:
|
||||
"""Create a base queryset for a SQLModel class."""
|
||||
return QuerySet(select(model))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Database engine, session factory, and startup migration helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import anyio
|
||||
from alembic import command
|
||||
@@ -15,6 +17,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app import models as _models
|
||||
from app.core.config import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
# Import model modules so SQLModel metadata is fully registered at startup.
|
||||
_MODEL_REGISTRY = _models
|
||||
|
||||
@@ -48,12 +53,14 @@ def _alembic_config() -> Config:
|
||||
|
||||
|
||||
def run_migrations() -> None:
|
||||
"""Apply Alembic migrations to the latest revision."""
|
||||
logger.info("Running database migrations.")
|
||||
command.upgrade(_alembic_config(), "head")
|
||||
logger.info("Database migrations complete.")
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize database schema, running migrations when configured."""
|
||||
if settings.db_auto_migrate:
|
||||
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
|
||||
if any(versions_dir.glob("*.py")):
|
||||
@@ -67,6 +74,7 @@ async def init_db() -> None:
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Yield a request-scoped async DB session with safe rollback on errors."""
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""External system integration clients and protocol adapters."""
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""OpenClaw gateway protocol constants shared across integration layers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
PROTOCOL_VERSION = 3
|
||||
@@ -116,4 +118,5 @@ GATEWAY_EVENTS_SET = frozenset(GATEWAY_EVENTS)
|
||||
|
||||
|
||||
def is_known_gateway_method(method: str) -> bool:
|
||||
"""Return whether a method name is part of the known base gateway methods."""
|
||||
return method in GATEWAY_METHODS_SET
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""FastAPI application entrypoint and router wiring for the backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -29,11 +31,15 @@ from app.core.error_handling import install_error_handling
|
||||
from app.core.logging import configure_logging
|
||||
from app.db.session import init_db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
configure_logging()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
||||
"""Initialize application resources before serving requests."""
|
||||
await init_db()
|
||||
yield
|
||||
|
||||
@@ -55,16 +61,19 @@ install_error_handling(app)
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict[str, bool]:
|
||||
"""Lightweight liveness probe endpoint."""
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
def healthz() -> dict[str, bool]:
|
||||
"""Alias liveness probe endpoint for platform compatibility."""
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/readyz")
|
||||
def readyz() -> dict[str, bool]:
|
||||
"""Readiness probe endpoint for service orchestration checks."""
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Model exports for SQLAlchemy/SQLModel metadata discovery."""
|
||||
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approvals import Approval
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Activity event model persisted for audit and feed use-cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field
|
||||
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class ActivityEvent(QueryModel, table=True):
|
||||
"""Discrete activity event tied to tasks and agents."""
|
||||
|
||||
__tablename__ = "activity_events"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Agent model representing autonomous actors assigned to boards."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -12,6 +14,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class Agent(QueryModel, table=True):
|
||||
"""Agent configuration and lifecycle state persisted in the database."""
|
||||
|
||||
__tablename__ = "agents"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
@@ -20,8 +24,14 @@ class Agent(QueryModel, table=True):
|
||||
status: str = Field(default="provisioning", index=True)
|
||||
openclaw_session_id: str | None = Field(default=None, index=True)
|
||||
agent_token_hash: str | None = Field(default=None, index=True)
|
||||
heartbeat_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
identity_profile: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
heartbeat_config: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
identity_profile: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
identity_template: str | None = Field(default=None, sa_column=Column(Text))
|
||||
soul_template: str | None = Field(default=None, sa_column=Column(Text))
|
||||
provision_requested_at: datetime | None = Field(default=None)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Approval model storing pending and resolved approval actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class Approval(QueryModel, table=True):
|
||||
"""Approval request and decision metadata for gated operations."""
|
||||
|
||||
__tablename__ = "approvals"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Base model mixins and shared SQLModel abstractions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Self
|
||||
@@ -8,4 +10,6 @@ from app.db.query_manager import ManagerDescriptor
|
||||
|
||||
|
||||
class QueryModel(SQLModel, table=False):
|
||||
"""Base SQLModel with a shared query manager descriptor."""
|
||||
|
||||
objects: ClassVar[ManagerDescriptor[Self]] = ManagerDescriptor()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Board-group scoped memory entries for shared context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class BoardGroupMemory(QueryModel, table=True):
|
||||
"""Persisted memory items associated with a board group."""
|
||||
|
||||
__tablename__ = "board_group_memory"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Board group model used to organize boards inside organizations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field
|
||||
@@ -10,6 +12,8 @@ from app.models.tenancy import TenantScoped
|
||||
|
||||
|
||||
class BoardGroup(TenantScoped, table=True):
|
||||
"""Logical grouping container for boards within an organization."""
|
||||
|
||||
__tablename__ = "board_groups"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Board-level memory entries for persistent contextual state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class BoardMemory(QueryModel, table=True):
|
||||
"""Persisted memory item attached directly to a board."""
|
||||
|
||||
__tablename__ = "board_memory"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Board onboarding session model for guided setup state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
@@ -11,13 +13,18 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class BoardOnboardingSession(QueryModel, table=True):
|
||||
"""Persisted onboarding conversation and draft goal data for a board."""
|
||||
|
||||
__tablename__ = "board_onboarding_sessions"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
||||
session_key: str
|
||||
status: str = Field(default="active", index=True)
|
||||
messages: list[dict[str, object]] | None = Field(default=None, sa_column=Column(JSON))
|
||||
messages: list[dict[str, object]] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
draft_goal: dict[str, object] | None = Field(default=None, sa_column=Column(JSON))
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Board model for organization workspaces and goal configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
@@ -11,6 +13,8 @@ from app.models.tenancy import TenantScoped
|
||||
|
||||
|
||||
class Board(TenantScoped, table=True):
|
||||
"""Primary board entity grouping tasks, agents, and goal metadata."""
|
||||
|
||||
__tablename__ = "boards"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
@@ -18,10 +22,17 @@ class Board(TenantScoped, table=True):
|
||||
name: str
|
||||
slug: str = Field(index=True)
|
||||
gateway_id: UUID | None = Field(default=None, foreign_key="gateways.id", index=True)
|
||||
board_group_id: UUID | None = Field(default=None, foreign_key="board_groups.id", index=True)
|
||||
board_group_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="board_groups.id",
|
||||
index=True,
|
||||
)
|
||||
board_type: str = Field(default="goal", index=True)
|
||||
objective: str | None = None
|
||||
success_metrics: dict[str, object] | None = Field(default=None, sa_column=Column(JSON))
|
||||
success_metrics: dict[str, object] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
target_date: datetime | None = None
|
||||
goal_confirmed: bool = Field(default=False)
|
||||
goal_source: str | None = None
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Gateway model storing organization-level gateway integration metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field
|
||||
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class Gateway(QueryModel, table=True):
|
||||
"""Configured external gateway endpoint and authentication settings."""
|
||||
|
||||
__tablename__ = "gateways"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Board-level access grants assigned to organization members."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class OrganizationBoardAccess(QueryModel, table=True):
|
||||
"""Member-specific board permissions within an organization."""
|
||||
|
||||
__tablename__ = "organization_board_access"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -21,7 +25,10 @@ class OrganizationBoardAccess(QueryModel, table=True):
|
||||
)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
organization_member_id: UUID = Field(foreign_key="organization_members.id", index=True)
|
||||
organization_member_id: UUID = Field(
|
||||
foreign_key="organization_members.id",
|
||||
index=True,
|
||||
)
|
||||
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
||||
can_read: bool = Field(default=True)
|
||||
can_write: bool = Field(default=False)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Board access grants attached to pending organization invites."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class OrganizationInviteBoardAccess(QueryModel, table=True):
|
||||
"""Invite-specific board permissions applied after invite acceptance."""
|
||||
|
||||
__tablename__ = "organization_invite_board_access"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -21,7 +25,10 @@ class OrganizationInviteBoardAccess(QueryModel, table=True):
|
||||
)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
organization_invite_id: UUID = Field(foreign_key="organization_invites.id", index=True)
|
||||
organization_invite_id: UUID = Field(
|
||||
foreign_key="organization_invites.id",
|
||||
index=True,
|
||||
)
|
||||
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
||||
can_read: bool = Field(default=True)
|
||||
can_write: bool = Field(default=False)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Organization invite model for email-based tenant membership flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class OrganizationInvite(QueryModel, table=True):
|
||||
"""Invitation record granting prospective organization access."""
|
||||
|
||||
__tablename__ = "organization_invites"
|
||||
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)
|
||||
|
||||
@@ -21,8 +25,16 @@ class OrganizationInvite(QueryModel, table=True):
|
||||
role: str = Field(default="member", index=True)
|
||||
all_boards_read: bool = Field(default=False)
|
||||
all_boards_write: bool = Field(default=False)
|
||||
created_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
|
||||
accepted_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
|
||||
created_by_user_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="users.id",
|
||||
index=True,
|
||||
)
|
||||
accepted_by_user_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="users.id",
|
||||
index=True,
|
||||
)
|
||||
accepted_at: datetime | None = None
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Organization membership model with role and board-access flags."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class OrganizationMember(QueryModel, table=True):
|
||||
"""Membership row linking a user to an organization and permissions."""
|
||||
|
||||
__tablename__ = "organization_members"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Organization model representing top-level tenant entities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class Organization(QueryModel, table=True):
|
||||
"""Top-level organization tenant record."""
|
||||
|
||||
__tablename__ = "organizations"
|
||||
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Task dependency edge model for board-local dependency graphs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import CheckConstraint, UniqueConstraint
|
||||
@@ -11,6 +13,8 @@ from app.models.tenancy import TenantScoped
|
||||
|
||||
|
||||
class TaskDependency(TenantScoped, table=True):
|
||||
"""Directed dependency edge between two tasks in the same board."""
|
||||
|
||||
__tablename__ = "task_dependencies"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Task fingerprint model for duplicate/task-linking operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field
|
||||
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class TaskFingerprint(QueryModel, table=True):
|
||||
"""Hashed task-content fingerprint associated with a board and task."""
|
||||
|
||||
__tablename__ = "task_fingerprints"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Task model representing board work items and execution metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field
|
||||
@@ -10,6 +12,8 @@ from app.models.tenancy import TenantScoped
|
||||
|
||||
|
||||
class Task(TenantScoped, table=True):
|
||||
"""Board-scoped task entity with ownership, status, and timing fields."""
|
||||
|
||||
__tablename__ = "tasks"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
@@ -22,8 +26,16 @@ class Task(TenantScoped, table=True):
|
||||
due_at: datetime | None = None
|
||||
in_progress_at: datetime | None = None
|
||||
|
||||
created_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
|
||||
assigned_agent_id: UUID | None = Field(default=None, foreign_key="agents.id", index=True)
|
||||
created_by_user_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="users.id",
|
||||
index=True,
|
||||
)
|
||||
assigned_agent_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="agents.id",
|
||||
index=True,
|
||||
)
|
||||
auto_created: bool = Field(default=False)
|
||||
auto_reason: str | None = None
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Shared tenancy-scoped model base classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.base import QueryModel
|
||||
|
||||
|
||||
class TenantScoped(QueryModel, table=False):
|
||||
pass
|
||||
"""Base class for models constrained to a tenant/organization scope."""
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""User model storing identity and profile preferences."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
@@ -8,6 +10,8 @@ from app.models.base import QueryModel
|
||||
|
||||
|
||||
class User(QueryModel, table=True):
|
||||
"""Application user account and profile attributes."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
@@ -21,5 +25,7 @@ class User(QueryModel, table=True):
|
||||
context: str | None = None
|
||||
is_super_admin: bool = Field(default=False)
|
||||
active_organization_id: UUID | None = Field(
|
||||
default=None, foreign_key="organizations.id", index=True
|
||||
default=None,
|
||||
foreign_key="organizations.id",
|
||||
index=True,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Public schema exports shared across API route modules."""
|
||||
|
||||
from app.schemas.activity_events import ActivityEventRead
|
||||
from app.schemas.agents import AgentCreate, AgentRead, AgentUpdate
|
||||
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalUpdate
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
"""Response schemas for activity events and task-comment feed items."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class ActivityEventRead(SQLModel):
|
||||
"""Serialized activity event payload returned by activity endpoints."""
|
||||
|
||||
id: UUID
|
||||
event_type: str
|
||||
message: str | None
|
||||
@@ -16,6 +20,8 @@ class ActivityEventRead(SQLModel):
|
||||
|
||||
|
||||
class ActivityTaskCommentFeedItemRead(SQLModel):
|
||||
"""Denormalized task-comment feed item enriched with task and board fields."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
message: str | None
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
"""Schemas for approval create/update/read API payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from typing import Literal, Self
|
||||
from uuid import UUID
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from pydantic import model_validator
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
ApprovalStatus = Literal["pending", "approved", "rejected"]
|
||||
STATUS_REQUIRED_ERROR = "status is required"
|
||||
|
||||
|
||||
class ApprovalBase(SQLModel):
|
||||
"""Shared approval fields used across create/read payloads."""
|
||||
|
||||
action_type: str
|
||||
task_id: UUID | None = None
|
||||
payload: dict[str, object] | None = None
|
||||
@@ -20,20 +25,27 @@ class ApprovalBase(SQLModel):
|
||||
|
||||
|
||||
class ApprovalCreate(ApprovalBase):
|
||||
"""Payload for creating a new approval request."""
|
||||
|
||||
agent_id: UUID | None = None
|
||||
|
||||
|
||||
class ApprovalUpdate(SQLModel):
|
||||
"""Payload for mutating approval status."""
|
||||
|
||||
status: ApprovalStatus | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_status(self) -> Self:
|
||||
"""Ensure explicitly provided `status` is not null."""
|
||||
if "status" in self.model_fields_set and self.status is None:
|
||||
raise ValueError("status is required")
|
||||
raise ValueError(STATUS_REQUIRED_ERROR)
|
||||
return self
|
||||
|
||||
|
||||
class ApprovalRead(ApprovalBase):
|
||||
"""Approval payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
board_id: UUID
|
||||
agent_id: UUID | None = None
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
"""Schemas for applying heartbeat settings to board-group agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class BoardGroupHeartbeatApply(SQLModel):
|
||||
# Heartbeat cadence string understood by the OpenClaw gateway (e.g. "2m", "10m", "30m").
|
||||
"""Request payload for heartbeat policy updates."""
|
||||
|
||||
# Heartbeat cadence string understood by the OpenClaw gateway
|
||||
# (e.g. "2m", "10m", "30m").
|
||||
every: str
|
||||
# Optional heartbeat target (most deployments use "none").
|
||||
target: str | None = None
|
||||
@@ -15,6 +20,8 @@ class BoardGroupHeartbeatApply(SQLModel):
|
||||
|
||||
|
||||
class BoardGroupHeartbeatApplyResult(SQLModel):
|
||||
"""Result payload describing agents updated by a heartbeat request."""
|
||||
|
||||
board_group_id: UUID
|
||||
requested: dict[str, Any]
|
||||
updated_agent_ids: list[UUID]
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""Schemas for board-group memory create/read API payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.schemas.common import NonEmptyStr
|
||||
from app.schemas.common import NonEmptyStr # noqa: TCH001
|
||||
|
||||
|
||||
class BoardGroupMemoryCreate(SQLModel):
|
||||
"""Payload for creating a board-group memory entry."""
|
||||
|
||||
# For writes, reject blank/whitespace-only content.
|
||||
content: NonEmptyStr
|
||||
tags: list[str] | None = None
|
||||
@@ -16,9 +20,12 @@ class BoardGroupMemoryCreate(SQLModel):
|
||||
|
||||
|
||||
class BoardGroupMemoryRead(SQLModel):
|
||||
"""Serialized board-group memory entry returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
board_group_id: UUID
|
||||
# For reads, allow legacy rows that may have empty content (avoid response validation 500s).
|
||||
# For reads, allow legacy rows that may have empty content
|
||||
# (avoid response validation 500s).
|
||||
content: str
|
||||
tags: list[str] | None = None
|
||||
source: str | None = None
|
||||
|
||||
@@ -1,28 +1,36 @@
|
||||
"""Schemas for board-group create/update/read API operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class BoardGroupBase(SQLModel):
|
||||
"""Shared board-group fields for create/read operations."""
|
||||
|
||||
name: str
|
||||
slug: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class BoardGroupCreate(BoardGroupBase):
|
||||
pass
|
||||
"""Payload for creating a board group."""
|
||||
|
||||
|
||||
class BoardGroupUpdate(SQLModel):
|
||||
"""Payload for partial board-group updates."""
|
||||
|
||||
name: str | None = None
|
||||
slug: str | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class BoardGroupRead(BoardGroupBase):
|
||||
"""Board-group payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
organization_id: UUID
|
||||
created_at: datetime
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""Schemas for board memory create/read API payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.schemas.common import NonEmptyStr
|
||||
from app.schemas.common import NonEmptyStr # noqa: TCH001
|
||||
|
||||
|
||||
class BoardMemoryCreate(SQLModel):
|
||||
"""Payload for creating a board memory entry."""
|
||||
|
||||
# For writes, reject blank/whitespace-only content.
|
||||
content: NonEmptyStr
|
||||
tags: list[str] | None = None
|
||||
@@ -16,9 +20,12 @@ class BoardMemoryCreate(SQLModel):
|
||||
|
||||
|
||||
class BoardMemoryRead(SQLModel):
|
||||
"""Serialized board memory entry returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
board_id: UUID
|
||||
# For reads, allow legacy rows that may have empty content (avoid response validation 500s).
|
||||
# For reads, allow legacy rows that may have empty content
|
||||
# (avoid response validation 500s).
|
||||
content: str
|
||||
tags: list[str] | None = None
|
||||
source: str | None = None
|
||||
|
||||
@@ -1,14 +1,23 @@
|
||||
"""Schemas for board create/update/read API operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from pydantic import model_validator
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
_ERR_GOAL_FIELDS_REQUIRED = (
|
||||
"Confirmed goal boards require objective and success_metrics"
|
||||
)
|
||||
_ERR_GATEWAY_REQUIRED = "gateway_id is required"
|
||||
|
||||
|
||||
class BoardBase(SQLModel):
|
||||
"""Shared board fields used across create and read payloads."""
|
||||
|
||||
name: str
|
||||
slug: str
|
||||
gateway_id: UUID | None = None
|
||||
@@ -22,17 +31,25 @@ class BoardBase(SQLModel):
|
||||
|
||||
|
||||
class BoardCreate(BoardBase):
|
||||
"""Payload for creating a board."""
|
||||
|
||||
gateway_id: UUID
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_goal_fields(self) -> Self:
|
||||
if self.board_type == "goal" and self.goal_confirmed:
|
||||
if not self.objective or not self.success_metrics:
|
||||
raise ValueError("Confirmed goal boards require objective and success_metrics")
|
||||
"""Require goal details when creating a confirmed goal board."""
|
||||
if (
|
||||
self.board_type == "goal"
|
||||
and self.goal_confirmed
|
||||
and (not self.objective or not self.success_metrics)
|
||||
):
|
||||
raise ValueError(_ERR_GOAL_FIELDS_REQUIRED)
|
||||
return self
|
||||
|
||||
|
||||
class BoardUpdate(SQLModel):
|
||||
"""Payload for partial board updates."""
|
||||
|
||||
name: str | None = None
|
||||
slug: str | None = None
|
||||
gateway_id: UUID | None = None
|
||||
@@ -46,13 +63,16 @@ class BoardUpdate(SQLModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_gateway_id(self) -> Self:
|
||||
"""Reject explicit null gateway IDs in patch payloads."""
|
||||
# Treat explicit null like "unset" is invalid for patch updates.
|
||||
if "gateway_id" in self.model_fields_set and self.gateway_id is None:
|
||||
raise ValueError("gateway_id is required")
|
||||
raise ValueError(_ERR_GATEWAY_REQUIRED)
|
||||
return self
|
||||
|
||||
|
||||
class BoardRead(BoardBase):
|
||||
"""Board payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
organization_id: UUID
|
||||
created_at: datetime
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Common reusable schema primitives and simple API response envelopes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
@@ -5,9 +7,12 @@ from typing import Annotated
|
||||
from pydantic import StringConstraints
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
# Reusable string type for request payloads where blank/whitespace-only values are invalid.
|
||||
# Reusable string type for request payloads where blank/whitespace-only values
|
||||
# are invalid.
|
||||
NonEmptyStr = Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]
|
||||
|
||||
|
||||
class OkResponse(SQLModel):
|
||||
"""Standard success response payload."""
|
||||
|
||||
ok: bool = True
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
"""Structured error payload schemas used by API responses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class BlockedTaskDetail(SQLModel):
|
||||
"""Error detail payload listing blocking dependency task identifiers."""
|
||||
|
||||
message: str
|
||||
blocked_by_task_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BlockedTaskError(SQLModel):
|
||||
"""Top-level blocked-task error response envelope."""
|
||||
|
||||
detail: BlockedTaskDetail
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
"""Schemas for gateway passthrough API request and response payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.schemas.common import NonEmptyStr
|
||||
from app.schemas.common import NonEmptyStr # noqa: TCH001
|
||||
|
||||
|
||||
class GatewaySessionMessageRequest(SQLModel):
|
||||
"""Request payload for sending a message into a gateway session."""
|
||||
|
||||
content: NonEmptyStr
|
||||
|
||||
|
||||
class GatewayResolveQuery(SQLModel):
|
||||
"""Query parameters used to resolve which gateway to target."""
|
||||
|
||||
board_id: str | None = None
|
||||
gateway_url: str | None = None
|
||||
gateway_token: str | None = None
|
||||
@@ -17,6 +23,8 @@ class GatewayResolveQuery(SQLModel):
|
||||
|
||||
|
||||
class GatewaysStatusResponse(SQLModel):
|
||||
"""Aggregated gateway status response including session metadata."""
|
||||
|
||||
connected: bool
|
||||
gateway_url: str
|
||||
sessions_count: int | None = None
|
||||
@@ -28,20 +36,28 @@ class GatewaysStatusResponse(SQLModel):
|
||||
|
||||
|
||||
class GatewaySessionsResponse(SQLModel):
|
||||
"""Gateway sessions list response payload."""
|
||||
|
||||
sessions: list[object]
|
||||
main_session_key: str | None = None
|
||||
main_session: object | None = None
|
||||
|
||||
|
||||
class GatewaySessionResponse(SQLModel):
|
||||
"""Single gateway session response payload."""
|
||||
|
||||
session: object
|
||||
|
||||
|
||||
class GatewaySessionHistoryResponse(SQLModel):
|
||||
"""Gateway session history response payload."""
|
||||
|
||||
history: list[object]
|
||||
|
||||
|
||||
class GatewayCommandsResponse(SQLModel):
|
||||
"""Gateway command catalog and protocol metadata."""
|
||||
|
||||
protocol_version: int
|
||||
methods: list[str]
|
||||
events: list[str]
|
||||
|
||||
@@ -1,24 +1,38 @@
|
||||
"""Schemas for gateway-main and lead-agent coordination endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from app.schemas.common import NonEmptyStr
|
||||
from app.schemas.common import NonEmptyStr # noqa: TCH001
|
||||
|
||||
|
||||
def _lead_reply_tags() -> list[str]:
|
||||
return ["gateway_main", "lead_reply"]
|
||||
|
||||
|
||||
def _user_reply_tags() -> list[str]:
|
||||
return ["gateway_main", "user_reply"]
|
||||
|
||||
|
||||
class GatewayLeadMessageRequest(SQLModel):
|
||||
"""Request payload for sending a message to a board lead agent."""
|
||||
|
||||
kind: Literal["question", "handoff"] = "question"
|
||||
correlation_id: str | None = None
|
||||
content: NonEmptyStr
|
||||
|
||||
# How the lead should reply (defaults are interpreted by templates).
|
||||
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "lead_reply"])
|
||||
reply_tags: list[str] = Field(default_factory=_lead_reply_tags)
|
||||
reply_source: str | None = "lead_to_gateway_main"
|
||||
|
||||
|
||||
class GatewayLeadMessageResponse(SQLModel):
|
||||
"""Response payload for a lead-message dispatch attempt."""
|
||||
|
||||
ok: bool = True
|
||||
board_id: UUID
|
||||
lead_agent_id: UUID | None = None
|
||||
@@ -27,15 +41,19 @@ class GatewayLeadMessageResponse(SQLModel):
|
||||
|
||||
|
||||
class GatewayLeadBroadcastRequest(SQLModel):
|
||||
"""Request payload for broadcasting a message to multiple board leads."""
|
||||
|
||||
kind: Literal["question", "handoff"] = "question"
|
||||
correlation_id: str | None = None
|
||||
content: NonEmptyStr
|
||||
board_ids: list[UUID] | None = None
|
||||
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "lead_reply"])
|
||||
reply_tags: list[str] = Field(default_factory=_lead_reply_tags)
|
||||
reply_source: str | None = "lead_to_gateway_main"
|
||||
|
||||
|
||||
class GatewayLeadBroadcastBoardResult(SQLModel):
|
||||
"""Per-board result entry for a lead broadcast operation."""
|
||||
|
||||
board_id: UUID
|
||||
lead_agent_id: UUID | None = None
|
||||
lead_agent_name: str | None = None
|
||||
@@ -44,6 +62,8 @@ class GatewayLeadBroadcastBoardResult(SQLModel):
|
||||
|
||||
|
||||
class GatewayLeadBroadcastResponse(SQLModel):
|
||||
"""Aggregate response for a lead broadcast operation."""
|
||||
|
||||
ok: bool = True
|
||||
sent: int = 0
|
||||
failed: int = 0
|
||||
@@ -51,16 +71,21 @@ class GatewayLeadBroadcastResponse(SQLModel):
|
||||
|
||||
|
||||
class GatewayMainAskUserRequest(SQLModel):
|
||||
"""Request payload for asking the end user via a main gateway agent."""
|
||||
|
||||
correlation_id: str | None = None
|
||||
content: NonEmptyStr
|
||||
preferred_channel: str | None = None
|
||||
|
||||
# How the main agent should reply back into Mission Control (defaults interpreted by templates).
|
||||
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "user_reply"])
|
||||
# How the main agent should reply back into Mission Control
|
||||
# (defaults interpreted by templates).
|
||||
reply_tags: list[str] = Field(default_factory=_user_reply_tags)
|
||||
reply_source: str | None = "user_via_gateway_main"
|
||||
|
||||
|
||||
class GatewayMainAskUserResponse(SQLModel):
|
||||
"""Response payload for user-question dispatch via gateway main agent."""
|
||||
|
||||
ok: bool = True
|
||||
board_id: UUID
|
||||
main_agent_id: UUID | None = None
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""Schemas for gateway CRUD and template-sync API payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from pydantic import field_validator
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class GatewayBase(SQLModel):
|
||||
"""Shared gateway fields used across create/read payloads."""
|
||||
|
||||
name: str
|
||||
url: str
|
||||
main_session_key: str
|
||||
@@ -16,11 +19,14 @@ class GatewayBase(SQLModel):
|
||||
|
||||
|
||||
class GatewayCreate(GatewayBase):
|
||||
"""Payload for creating a gateway configuration."""
|
||||
|
||||
token: str | None = None
|
||||
|
||||
@field_validator("token", mode="before")
|
||||
@classmethod
|
||||
def normalize_token(cls, value: Any) -> Any:
|
||||
def normalize_token(cls, value: object) -> str | None | object:
|
||||
"""Normalize empty/whitespace tokens to `None`."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
@@ -30,6 +36,8 @@ class GatewayCreate(GatewayBase):
|
||||
|
||||
|
||||
class GatewayUpdate(SQLModel):
|
||||
"""Payload for partial gateway updates."""
|
||||
|
||||
name: str | None = None
|
||||
url: str | None = None
|
||||
token: str | None = None
|
||||
@@ -38,7 +46,8 @@ class GatewayUpdate(SQLModel):
|
||||
|
||||
@field_validator("token", mode="before")
|
||||
@classmethod
|
||||
def normalize_token(cls, value: Any) -> Any:
|
||||
def normalize_token(cls, value: object) -> str | None | object:
|
||||
"""Normalize empty/whitespace tokens to `None`."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
@@ -48,6 +57,8 @@ class GatewayUpdate(SQLModel):
|
||||
|
||||
|
||||
class GatewayRead(GatewayBase):
|
||||
"""Gateway payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
organization_id: UUID
|
||||
token: str | None = None
|
||||
@@ -56,6 +67,8 @@ class GatewayRead(GatewayBase):
|
||||
|
||||
|
||||
class GatewayTemplatesSyncError(SQLModel):
|
||||
"""Per-agent error entry from a gateway template sync operation."""
|
||||
|
||||
agent_id: UUID | None = None
|
||||
agent_name: str | None = None
|
||||
board_id: UUID | None = None
|
||||
@@ -63,6 +76,8 @@ class GatewayTemplatesSyncError(SQLModel):
|
||||
|
||||
|
||||
class GatewayTemplatesSyncResult(SQLModel):
|
||||
"""Summary payload returned by gateway template sync endpoints."""
|
||||
|
||||
gateway_id: UUID
|
||||
include_main: bool
|
||||
reset_sessions: bool
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
"""Dashboard metrics schemas for KPI and time-series API responses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from typing import Literal
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class DashboardSeriesPoint(SQLModel):
|
||||
"""Single numeric time-series point."""
|
||||
|
||||
period: datetime
|
||||
value: float
|
||||
|
||||
|
||||
class DashboardWipPoint(SQLModel):
|
||||
"""Work-in-progress point split by task status buckets."""
|
||||
|
||||
period: datetime
|
||||
inbox: int
|
||||
in_progress: int
|
||||
@@ -19,28 +25,38 @@ class DashboardWipPoint(SQLModel):
|
||||
|
||||
|
||||
class DashboardRangeSeries(SQLModel):
|
||||
"""Series payload for a single range/bucket combination."""
|
||||
|
||||
range: Literal["24h", "7d"]
|
||||
bucket: Literal["hour", "day"]
|
||||
points: list[DashboardSeriesPoint]
|
||||
|
||||
|
||||
class DashboardWipRangeSeries(SQLModel):
|
||||
"""WIP series payload for a single range/bucket combination."""
|
||||
|
||||
range: Literal["24h", "7d"]
|
||||
bucket: Literal["hour", "day"]
|
||||
points: list[DashboardWipPoint]
|
||||
|
||||
|
||||
class DashboardSeriesSet(SQLModel):
|
||||
"""Primary vs comparison pair for generic series metrics."""
|
||||
|
||||
primary: DashboardRangeSeries
|
||||
comparison: DashboardRangeSeries
|
||||
|
||||
|
||||
class DashboardWipSeriesSet(SQLModel):
|
||||
"""Primary vs comparison pair for WIP status series metrics."""
|
||||
|
||||
primary: DashboardWipRangeSeries
|
||||
comparison: DashboardWipRangeSeries
|
||||
|
||||
|
||||
class DashboardKpis(SQLModel):
|
||||
"""Topline dashboard KPI summary values."""
|
||||
|
||||
active_agents: int
|
||||
tasks_in_progress: int
|
||||
error_rate_pct: float
|
||||
@@ -48,6 +64,8 @@ class DashboardKpis(SQLModel):
|
||||
|
||||
|
||||
class DashboardMetrics(SQLModel):
|
||||
"""Complete dashboard metrics response payload."""
|
||||
|
||||
range: Literal["24h", "7d"]
|
||||
generated_at: datetime
|
||||
kpis: DashboardKpis
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
"""Schemas for organization, membership, and invite API payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class OrganizationRead(SQLModel):
|
||||
"""Organization payload returned by read endpoints."""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
created_at: datetime
|
||||
@@ -14,14 +18,20 @@ class OrganizationRead(SQLModel):
|
||||
|
||||
|
||||
class OrganizationCreate(SQLModel):
|
||||
"""Payload for creating a new organization."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class OrganizationActiveUpdate(SQLModel):
|
||||
"""Payload for switching the active organization context."""
|
||||
|
||||
organization_id: UUID
|
||||
|
||||
|
||||
class OrganizationListItem(SQLModel):
|
||||
"""Organization list row for current user memberships."""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
role: str
|
||||
@@ -29,6 +39,8 @@ class OrganizationListItem(SQLModel):
|
||||
|
||||
|
||||
class OrganizationUserRead(SQLModel):
|
||||
"""Embedded user fields included in organization member payloads."""
|
||||
|
||||
id: UUID
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
@@ -36,6 +48,8 @@ class OrganizationUserRead(SQLModel):
|
||||
|
||||
|
||||
class OrganizationMemberRead(SQLModel):
|
||||
"""Organization member payload including board-level access overrides."""
|
||||
|
||||
id: UUID
|
||||
organization_id: UUID
|
||||
user_id: UUID
|
||||
@@ -49,16 +63,22 @@ class OrganizationMemberRead(SQLModel):
|
||||
|
||||
|
||||
class OrganizationMemberUpdate(SQLModel):
|
||||
"""Payload for partial updates to organization member role."""
|
||||
|
||||
role: str | None = None
|
||||
|
||||
|
||||
class OrganizationBoardAccessSpec(SQLModel):
|
||||
"""Board access specification used in member/invite mutation payloads."""
|
||||
|
||||
board_id: UUID
|
||||
can_read: bool = True
|
||||
can_write: bool = False
|
||||
|
||||
|
||||
class OrganizationBoardAccessRead(SQLModel):
|
||||
"""Board access payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
board_id: UUID
|
||||
can_read: bool
|
||||
@@ -68,12 +88,16 @@ class OrganizationBoardAccessRead(SQLModel):
|
||||
|
||||
|
||||
class OrganizationMemberAccessUpdate(SQLModel):
|
||||
"""Payload for replacing organization member access permissions."""
|
||||
|
||||
all_boards_read: bool = False
|
||||
all_boards_write: bool = False
|
||||
board_access: list[OrganizationBoardAccessSpec] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrganizationInviteCreate(SQLModel):
|
||||
"""Payload for creating an organization invite."""
|
||||
|
||||
invited_email: str
|
||||
role: str = "member"
|
||||
all_boards_read: bool = False
|
||||
@@ -82,6 +106,8 @@ class OrganizationInviteCreate(SQLModel):
|
||||
|
||||
|
||||
class OrganizationInviteRead(SQLModel):
|
||||
"""Organization invite payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
organization_id: UUID
|
||||
invited_email: str
|
||||
@@ -97,4 +123,6 @@ class OrganizationInviteRead(SQLModel):
|
||||
|
||||
|
||||
class OrganizationInviteAccept(SQLModel):
|
||||
"""Payload for accepting an organization invite token."""
|
||||
|
||||
token: str
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Shared pagination response type aliases used by API routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
"""Schemas for souls-directory search and markdown fetch responses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SoulsDirectorySoulRef(BaseModel):
|
||||
"""Reference metadata for a soul entry in the directory index."""
|
||||
|
||||
handle: str
|
||||
slug: str
|
||||
page_url: str
|
||||
@@ -11,10 +15,14 @@ class SoulsDirectorySoulRef(BaseModel):
|
||||
|
||||
|
||||
class SoulsDirectorySearchResponse(BaseModel):
|
||||
"""Response wrapper for directory search results."""
|
||||
|
||||
items: list[SoulsDirectorySoulRef]
|
||||
|
||||
|
||||
class SoulsDirectoryMarkdownResponse(BaseModel):
|
||||
"""Response payload containing rendered markdown for a soul."""
|
||||
|
||||
handle: str
|
||||
slug: str
|
||||
content: str
|
||||
|
||||
@@ -1,18 +1,23 @@
|
||||
"""Schemas for task CRUD and task comment API payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Self
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from typing import Literal, Self
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from pydantic import field_validator, model_validator
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from app.schemas.common import NonEmptyStr
|
||||
from app.schemas.common import NonEmptyStr # noqa: TCH001
|
||||
|
||||
TaskStatus = Literal["inbox", "in_progress", "review", "done"]
|
||||
STATUS_REQUIRED_ERROR = "status is required"
|
||||
|
||||
|
||||
class TaskBase(SQLModel):
|
||||
"""Shared task fields used by task create/read payloads."""
|
||||
|
||||
title: str
|
||||
description: str | None = None
|
||||
status: TaskStatus = "inbox"
|
||||
@@ -23,10 +28,14 @@ class TaskBase(SQLModel):
|
||||
|
||||
|
||||
class TaskCreate(TaskBase):
|
||||
"""Payload for creating a task."""
|
||||
|
||||
created_by_user_id: UUID | None = None
|
||||
|
||||
|
||||
class TaskUpdate(SQLModel):
|
||||
"""Payload for partial task updates."""
|
||||
|
||||
title: str | None = None
|
||||
description: str | None = None
|
||||
status: TaskStatus | None = None
|
||||
@@ -38,7 +47,8 @@ class TaskUpdate(SQLModel):
|
||||
|
||||
@field_validator("comment", mode="before")
|
||||
@classmethod
|
||||
def normalize_comment(cls, value: Any) -> Any:
|
||||
def normalize_comment(cls, value: object) -> object | None:
|
||||
"""Normalize blank comment strings to `None`."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str) and not value.strip():
|
||||
@@ -47,12 +57,15 @@ class TaskUpdate(SQLModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_status(self) -> Self:
|
||||
"""Ensure explicitly supplied status is not null."""
|
||||
if "status" in self.model_fields_set and self.status is None:
|
||||
raise ValueError("status is required")
|
||||
raise ValueError(STATUS_REQUIRED_ERROR)
|
||||
return self
|
||||
|
||||
|
||||
class TaskRead(TaskBase):
|
||||
"""Task payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
board_id: UUID | None
|
||||
created_by_user_id: UUID | None
|
||||
@@ -64,10 +77,14 @@ class TaskRead(TaskBase):
|
||||
|
||||
|
||||
class TaskCommentCreate(SQLModel):
|
||||
"""Payload for creating a task comment."""
|
||||
|
||||
message: NonEmptyStr
|
||||
|
||||
|
||||
class TaskCommentRead(SQLModel):
|
||||
"""Task comment payload returned from read endpoints."""
|
||||
|
||||
id: UUID
|
||||
message: str | None
|
||||
agent_id: UUID | None
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""User API schemas for create, update, and read operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class UserBase(SQLModel):
|
||||
"""Common user profile fields shared across user payload schemas."""
|
||||
|
||||
clerk_user_id: str
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
@@ -17,10 +21,12 @@ class UserBase(SQLModel):
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
pass
|
||||
"""Payload used to create a user record."""
|
||||
|
||||
|
||||
class UserUpdate(SQLModel):
|
||||
"""Payload for partial user profile updates."""
|
||||
|
||||
name: str | None = None
|
||||
preferred_name: str | None = None
|
||||
pronouns: str | None = None
|
||||
@@ -30,5 +36,7 @@ class UserUpdate(SQLModel):
|
||||
|
||||
|
||||
class UserRead(UserBase):
|
||||
"""Full user payload returned by API responses."""
|
||||
|
||||
id: UUID
|
||||
is_super_admin: bool
|
||||
|
||||
@@ -1,25 +1,31 @@
|
||||
"""Composite read models assembled for board and board-group views."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from datetime import datetime # noqa: TCH003
|
||||
from uuid import UUID # noqa: TCH003
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from app.schemas.agents import AgentRead
|
||||
from app.schemas.approvals import ApprovalRead
|
||||
from app.schemas.board_groups import BoardGroupRead
|
||||
from app.schemas.board_memory import BoardMemoryRead
|
||||
from app.schemas.boards import BoardRead
|
||||
from app.schemas.agents import AgentRead # noqa: TCH001
|
||||
from app.schemas.approvals import ApprovalRead # noqa: TCH001
|
||||
from app.schemas.board_groups import BoardGroupRead # noqa: TCH001
|
||||
from app.schemas.board_memory import BoardMemoryRead # noqa: TCH001
|
||||
from app.schemas.boards import BoardRead # noqa: TCH001
|
||||
from app.schemas.tasks import TaskRead
|
||||
|
||||
|
||||
class TaskCardRead(TaskRead):
|
||||
"""Task read model enriched with assignee and approval counters."""
|
||||
|
||||
assignee: str | None = None
|
||||
approvals_count: int = 0
|
||||
approvals_pending_count: int = 0
|
||||
|
||||
|
||||
class BoardSnapshot(SQLModel):
|
||||
"""Aggregated board payload used by board snapshot endpoints."""
|
||||
|
||||
board: BoardRead
|
||||
tasks: list[TaskCardRead]
|
||||
agents: list[AgentRead]
|
||||
@@ -29,6 +35,8 @@ class BoardSnapshot(SQLModel):
|
||||
|
||||
|
||||
class BoardGroupTaskSummary(SQLModel):
|
||||
"""Task summary row used inside board-group snapshot responses."""
|
||||
|
||||
id: UUID
|
||||
board_id: UUID
|
||||
board_name: str
|
||||
@@ -44,11 +52,15 @@ class BoardGroupTaskSummary(SQLModel):
|
||||
|
||||
|
||||
class BoardGroupBoardSnapshot(SQLModel):
|
||||
"""Board-level rollup embedded within a board-group snapshot."""
|
||||
|
||||
board: BoardRead
|
||||
task_counts: dict[str, int] = Field(default_factory=dict)
|
||||
tasks: list[BoardGroupTaskSummary] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BoardGroupSnapshot(SQLModel):
|
||||
"""Top-level board-group snapshot response payload."""
|
||||
|
||||
group: BoardGroupRead | None = None
|
||||
boards: list[BoardGroupBoardSnapshot] = Field(default_factory=list)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Business logic services for backend domain operations."""
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Utilities for recording normalized activity events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.activity_events import ActivityEvent
|
||||
|
||||
@@ -15,6 +20,7 @@ def record_activity(
|
||||
agent_id: UUID | None = None,
|
||||
task_id: UUID | None = None,
|
||||
) -> ActivityEvent:
|
||||
"""Create and attach an activity event row to the current DB session."""
|
||||
event = ActivityEvent(
|
||||
event_type=event_type,
|
||||
message=message,
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
"""Access control helpers for admin-only operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
if TYPE_CHECKING:
|
||||
from app.core.auth import AuthContext
|
||||
|
||||
|
||||
def require_admin(auth: AuthContext) -> None:
|
||||
"""Raise HTTP 403 unless the authenticated actor is a user admin."""
|
||||
if auth.actor_type != "user" or auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
@@ -1,21 +1,31 @@
|
||||
"""Gateway-facing agent provisioning and cleanup helpers."""
|
||||
# ruff: noqa: EM101, TRY003
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
|
||||
|
||||
from app.core.config import settings
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call
|
||||
from app.models.agents import Agent
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
openclaw_call,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.agents import Agent
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
|
||||
DEFAULT_HEARTBEAT_CONFIG = {"every": "10m", "target": "none"}
|
||||
DEFAULT_IDENTITY_PROFILE = {
|
||||
@@ -35,7 +45,8 @@ EXTRA_IDENTITY_PROFILE_FIELDS = {
|
||||
"verbosity": "identity_verbosity",
|
||||
"output_format": "identity_output_format",
|
||||
"update_cadence": "identity_update_cadence",
|
||||
# Per-agent charter (optional). Used to give agents a "purpose in life" and a distinct vibe.
|
||||
# Per-agent charter (optional).
|
||||
# Used to give agents a "purpose in life" and a distinct vibe.
|
||||
"purpose": "identity_purpose",
|
||||
"personality": "identity_personality",
|
||||
"custom_instructions": "identity_custom_instructions",
|
||||
@@ -54,11 +65,11 @@ DEFAULT_GATEWAY_FILES = frozenset(
|
||||
"BOOT.md",
|
||||
"BOOTSTRAP.md",
|
||||
"MEMORY.md",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# These files are intended to evolve within the agent workspace. Provision them if missing,
|
||||
# but avoid overwriting existing content during updates.
|
||||
# These files are intended to evolve within the agent workspace.
|
||||
# Provision them if missing, but avoid overwriting existing content during updates.
|
||||
#
|
||||
# Examples:
|
||||
# - SELF.md: evolving identity/preferences
|
||||
@@ -68,6 +79,7 @@ PRESERVE_AGENT_EDITABLE_FILES = frozenset({"SELF.md", "USER.md", "MEMORY.md"})
|
||||
|
||||
HEARTBEAT_LEAD_TEMPLATE = "HEARTBEAT_LEAD.md"
|
||||
HEARTBEAT_AGENT_TEMPLATE = "HEARTBEAT_AGENT.md"
|
||||
_SESSION_KEY_PARTS_MIN = 2
|
||||
MAIN_TEMPLATE_MAP = {
|
||||
"AGENTS.md": "MAIN_AGENTS.md",
|
||||
"HEARTBEAT.md": "MAIN_HEARTBEAT.md",
|
||||
@@ -97,13 +109,13 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
|
||||
if not value.startswith("agent:"):
|
||||
return None
|
||||
parts = value.split(":")
|
||||
if len(parts) < 2:
|
||||
if len(parts) < _SESSION_KEY_PARTS_MIN:
|
||||
return None
|
||||
agent_id = parts[1].strip()
|
||||
return agent_id or None
|
||||
|
||||
|
||||
def _extract_agent_id(payload: object) -> str | None:
|
||||
def _extract_agent_id(payload: object) -> str | None: # noqa: C901
|
||||
def _from_list(items: object) -> str | None:
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
@@ -137,7 +149,7 @@ def _agent_key(agent: Agent) -> str:
|
||||
session_key = agent.openclaw_session_id or ""
|
||||
if session_key.startswith("agent:"):
|
||||
parts = session_key.split(":")
|
||||
if len(parts) >= 2 and parts[1]:
|
||||
if len(parts) >= _SESSION_KEY_PARTS_MIN and parts[1]:
|
||||
return parts[1]
|
||||
return _slugify(agent.name)
|
||||
|
||||
@@ -183,14 +195,14 @@ def _ensure_workspace_file(
|
||||
if not workspace_path or not name:
|
||||
return
|
||||
# Only write to a dedicated, explicitly-configured local directory.
|
||||
# Using `gateway.workspace_root` directly here is unsafe (and CodeQL correctly flags it)
|
||||
# because it is a DB-backed config value.
|
||||
# Using `gateway.workspace_root` directly here is unsafe.
|
||||
# CodeQL correctly flags that value because it is DB-backed config.
|
||||
base_root = (settings.local_agent_workspace_root or "").strip()
|
||||
if not base_root:
|
||||
return
|
||||
base = Path(base_root).expanduser()
|
||||
|
||||
# Derive a stable, safe directory name from the (potentially untrusted) workspace path.
|
||||
# Derive a stable, safe directory name from the untrusted workspace path.
|
||||
# This prevents path traversal and avoids writing to arbitrary locations.
|
||||
digest = hashlib.sha256(workspace_path.encode("utf-8")).hexdigest()[:16]
|
||||
root = base / f"gateway-workspace-{digest}"
|
||||
@@ -345,12 +357,14 @@ async def _supported_gateway_files(config: GatewayClientConfig) -> set[str]:
|
||||
default_id = None
|
||||
if isinstance(agents_payload, dict):
|
||||
agents = list(agents_payload.get("agents") or [])
|
||||
default_id = agents_payload.get("defaultId") or agents_payload.get("default_id")
|
||||
default_id = agents_payload.get("defaultId") or agents_payload.get(
|
||||
"default_id",
|
||||
)
|
||||
agent_id = default_id or (agents[0].get("id") if agents else None)
|
||||
if not agent_id:
|
||||
return set(DEFAULT_GATEWAY_FILES)
|
||||
files_payload = await openclaw_call(
|
||||
"agents.files.list", {"agentId": agent_id}, config=config
|
||||
"agents.files.list", {"agentId": agent_id}, config=config,
|
||||
)
|
||||
if isinstance(files_payload, dict):
|
||||
files = files_payload.get("files") or []
|
||||
@@ -374,10 +388,12 @@ async def _reset_session(session_key: str, config: GatewayClientConfig) -> None:
|
||||
|
||||
|
||||
async def _gateway_agent_files_index(
|
||||
agent_id: str, config: GatewayClientConfig
|
||||
agent_id: str, config: GatewayClientConfig,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
try:
|
||||
payload = await openclaw_call("agents.files.list", {"agentId": agent_id}, config=config)
|
||||
payload = await openclaw_call(
|
||||
"agents.files.list", {"agentId": agent_id}, config=config,
|
||||
)
|
||||
if isinstance(payload, dict):
|
||||
files = payload.get("files") or []
|
||||
index: dict[str, dict[str, Any]] = {}
|
||||
@@ -420,21 +436,25 @@ def _render_agent_files(
|
||||
)
|
||||
heartbeat_path = _templates_root() / heartbeat_template
|
||||
if heartbeat_path.exists():
|
||||
rendered[name] = env.get_template(heartbeat_template).render(**context).strip()
|
||||
rendered[name] = (
|
||||
env.get_template(heartbeat_template).render(**context).strip()
|
||||
)
|
||||
continue
|
||||
override = overrides.get(name)
|
||||
if override:
|
||||
rendered[name] = env.from_string(override).render(**context).strip()
|
||||
continue
|
||||
template_name = (
|
||||
template_overrides[name] if template_overrides and name in template_overrides else name
|
||||
template_overrides[name]
|
||||
if template_overrides and name in template_overrides
|
||||
else name
|
||||
)
|
||||
path = _templates_root() / template_name
|
||||
if path.exists():
|
||||
rendered[name] = env.get_template(template_name).render(**context).strip()
|
||||
continue
|
||||
if name == "MEMORY.md":
|
||||
# Back-compat fallback for existing gateways that don't ship a MEMORY.md template.
|
||||
# Back-compat fallback for gateways that do not ship MEMORY.md.
|
||||
rendered[name] = "# MEMORY\n\nBootstrap pending.\n"
|
||||
continue
|
||||
rendered[name] = ""
|
||||
@@ -487,7 +507,9 @@ async def _patch_gateway_agent_list(
|
||||
else:
|
||||
new_list.append(entry)
|
||||
if not updated:
|
||||
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
|
||||
new_list.append(
|
||||
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
|
||||
)
|
||||
|
||||
patch = {"agents": {"list": new_list}}
|
||||
params = {"raw": json.dumps(patch)}
|
||||
@@ -496,7 +518,7 @@ async def _patch_gateway_agent_list(
|
||||
await openclaw_call("config.patch", params, config=config)
|
||||
|
||||
|
||||
async def patch_gateway_agent_heartbeats(
|
||||
async def patch_gateway_agent_heartbeats( # noqa: C901
|
||||
gateway: Gateway,
|
||||
*,
|
||||
entries: list[tuple[str, str, dict[str, Any]]],
|
||||
@@ -521,7 +543,8 @@ async def patch_gateway_agent_heartbeats(
|
||||
raise OpenClawGatewayError("config agents.list is not a list")
|
||||
|
||||
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
|
||||
agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries
|
||||
agent_id: (workspace_path, heartbeat)
|
||||
for agent_id, workspace_path, heartbeat in entries
|
||||
}
|
||||
|
||||
updated_ids: set[str] = set()
|
||||
@@ -544,7 +567,9 @@ async def patch_gateway_agent_heartbeats(
|
||||
for agent_id, (workspace_path, heartbeat) in entry_by_id.items():
|
||||
if agent_id in updated_ids:
|
||||
continue
|
||||
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
|
||||
new_list.append(
|
||||
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
|
||||
)
|
||||
|
||||
patch = {"agents": {"list": new_list}}
|
||||
params = {"raw": json.dumps(patch)}
|
||||
@@ -585,7 +610,9 @@ async def _remove_gateway_agent_list(
|
||||
raise OpenClawGatewayError("config agents.list is not a list")
|
||||
|
||||
new_list = [
|
||||
entry for entry in lst if not (isinstance(entry, dict) and entry.get("id") == agent_id)
|
||||
entry
|
||||
for entry in lst
|
||||
if not (isinstance(entry, dict) and entry.get("id") == agent_id)
|
||||
]
|
||||
if len(new_list) == len(lst):
|
||||
return
|
||||
@@ -616,7 +643,7 @@ async def _get_gateway_agent_entry(
|
||||
return None
|
||||
|
||||
|
||||
async def provision_agent(
|
||||
async def provision_agent( # noqa: C901, PLR0912, PLR0913
|
||||
agent: Agent,
|
||||
board: Board,
|
||||
gateway: Gateway,
|
||||
@@ -627,6 +654,7 @@ async def provision_agent(
|
||||
force_bootstrap: bool = False,
|
||||
reset_session: bool = False,
|
||||
) -> None:
|
||||
"""Provision or update a regular board agent workspace."""
|
||||
if not gateway.url:
|
||||
return
|
||||
if not gateway.workspace_root:
|
||||
@@ -665,11 +693,9 @@ async def provision_agent(
|
||||
content = rendered.get(name)
|
||||
if not content:
|
||||
continue
|
||||
try:
|
||||
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
|
||||
except OSError:
|
||||
with suppress(OSError):
|
||||
# Local workspace may not be writable/available; fall back to gateway API.
|
||||
pass
|
||||
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
|
||||
for name, content in rendered.items():
|
||||
if content == "":
|
||||
continue
|
||||
@@ -694,7 +720,7 @@ async def provision_agent(
|
||||
await _reset_session(session_key, client_config)
|
||||
|
||||
|
||||
async def provision_main_agent(
|
||||
async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
|
||||
agent: Agent,
|
||||
gateway: Gateway,
|
||||
auth_token: str,
|
||||
@@ -704,12 +730,15 @@ async def provision_main_agent(
|
||||
force_bootstrap: bool = False,
|
||||
reset_session: bool = False,
|
||||
) -> None:
|
||||
"""Provision or update the gateway main agent workspace."""
|
||||
if not gateway.url:
|
||||
return
|
||||
if not gateway.main_session_key:
|
||||
raise ValueError("gateway main_session_key is required")
|
||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
|
||||
await ensure_session(
|
||||
gateway.main_session_key, config=client_config, label="Main Agent",
|
||||
)
|
||||
|
||||
agent_id = await _gateway_default_agent_id(
|
||||
client_config,
|
||||
@@ -763,6 +792,7 @@ async def cleanup_agent(
|
||||
agent: Agent,
|
||||
gateway: Gateway,
|
||||
) -> str | None:
|
||||
"""Remove an agent from gateway config and delete its session."""
|
||||
if not gateway.url:
|
||||
return None
|
||||
if not gateway.workspace_root:
|
||||
|
||||
@@ -1,30 +1,41 @@
|
||||
"""Helpers for ensuring each board has a provisioned lead agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||
from app.core.time import utcnow
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
send_message,
|
||||
)
|
||||
from app.models.agents import Agent
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
|
||||
|
||||
def lead_session_key(board: Board) -> str:
|
||||
"""Return the deterministic main session key for a board lead agent."""
|
||||
return f"agent:lead-{board.id}:main"
|
||||
|
||||
|
||||
def lead_agent_name(_: Board) -> str:
|
||||
"""Return the default display name for board lead agents."""
|
||||
return "Lead Agent"
|
||||
|
||||
|
||||
async def ensure_board_lead_agent(
|
||||
async def ensure_board_lead_agent( # noqa: PLR0913
|
||||
session: AsyncSession,
|
||||
*,
|
||||
board: Board,
|
||||
@@ -35,11 +46,12 @@ async def ensure_board_lead_agent(
|
||||
identity_profile: dict[str, str] | None = None,
|
||||
action: str = "provision",
|
||||
) -> tuple[Agent, bool]:
|
||||
"""Ensure a board has a lead agent; return `(agent, created)`."""
|
||||
existing = (
|
||||
await session.exec(
|
||||
select(Agent)
|
||||
.where(Agent.board_id == board.id)
|
||||
.where(col(Agent.is_board_lead).is_(True))
|
||||
.where(col(Agent.is_board_lead).is_(True)),
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
@@ -66,7 +78,11 @@ async def ensure_board_lead_agent(
|
||||
}
|
||||
if identity_profile:
|
||||
merged_identity_profile.update(
|
||||
{key: value.strip() for key, value in identity_profile.items() if value.strip()}
|
||||
{
|
||||
key: value.strip()
|
||||
for key, value in identity_profile.items()
|
||||
if value.strip()
|
||||
},
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
@@ -89,11 +105,16 @@ async def ensure_board_lead_agent(
|
||||
try:
|
||||
await provision_agent(agent, board, gateway, raw_token, user, action=action)
|
||||
if agent.openclaw_session_id:
|
||||
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name)
|
||||
await ensure_session(
|
||||
agent.openclaw_session_id,
|
||||
config=config,
|
||||
label=agent.name,
|
||||
)
|
||||
await send_message(
|
||||
(
|
||||
f"Hello {agent.name}. Your workspace has been provisioned.\n\n"
|
||||
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run it once "
|
||||
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run "
|
||||
"it once "
|
||||
"then delete it. Begin heartbeats after startup."
|
||||
),
|
||||
session_key=agent.openclaw_session_id,
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Helpers for assembling denormalized board snapshot response payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.time import utcnow
|
||||
from app.models.agents import Agent
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_memory import BoardMemory
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.tasks import Task
|
||||
from app.schemas.agents import AgentRead
|
||||
@@ -25,6 +25,13 @@ from app.services.task_dependencies import (
|
||||
dependency_status_by_id,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.boards import Board
|
||||
|
||||
OFFLINE_AFTER = timedelta(minutes=10)
|
||||
|
||||
|
||||
@@ -48,9 +55,15 @@ def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
|
||||
model = AgentRead.model_validate(agent, from_attributes=True)
|
||||
computed_status = _computed_agent_status(agent)
|
||||
is_gateway_main = bool(
|
||||
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys
|
||||
agent.openclaw_session_id
|
||||
and agent.openclaw_session_id in main_session_keys,
|
||||
)
|
||||
return model.model_copy(
|
||||
update={
|
||||
"status": computed_status,
|
||||
"is_gateway_main": is_gateway_main,
|
||||
},
|
||||
)
|
||||
return model.model_copy(update={"status": computed_status, "is_gateway_main": is_gateway_main})
|
||||
|
||||
|
||||
def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead:
|
||||
@@ -72,7 +85,9 @@ def _task_to_card(
|
||||
card = TaskCardRead.model_validate(task, from_attributes=True)
|
||||
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
|
||||
assignee = (
|
||||
agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id is not None else None
|
||||
agent_name_by_id.get(task.assigned_agent_id)
|
||||
if task.assigned_agent_id
|
||||
else None
|
||||
)
|
||||
depends_on_task_ids = deps_by_task_id.get(task.id, [])
|
||||
blocked_by_task_ids = blocked_by_dependency_ids(
|
||||
@@ -89,21 +104,26 @@ def _task_to_card(
|
||||
"depends_on_task_ids": depends_on_task_ids,
|
||||
"blocked_by_task_ids": blocked_by_task_ids,
|
||||
"is_blocked": bool(blocked_by_task_ids),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnapshot:
|
||||
"""Build a board snapshot with tasks, agents, approvals, and chat history."""
|
||||
board_read = BoardRead.model_validate(board, from_attributes=True)
|
||||
|
||||
tasks = list(
|
||||
await Task.objects.filter_by(board_id=board.id)
|
||||
.order_by(col(Task.created_at).desc())
|
||||
.all(session)
|
||||
.all(session),
|
||||
)
|
||||
task_ids = [task.id for task in tasks]
|
||||
|
||||
deps_by_task_id = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids)
|
||||
deps_by_task_id = await dependency_ids_by_task_id(
|
||||
session,
|
||||
board_id=board.id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
all_dependency_ids: list[UUID] = []
|
||||
for values in deps_by_task_id.values():
|
||||
all_dependency_ids.extend(values)
|
||||
@@ -127,9 +147,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
|
||||
await session.exec(
|
||||
select(func.count(col(Approval.id)))
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.status) == "pending")
|
||||
)
|
||||
).one()
|
||||
.where(col(Approval.status) == "pending"),
|
||||
),
|
||||
).one(),
|
||||
)
|
||||
|
||||
approvals = (
|
||||
@@ -146,12 +166,14 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
|
||||
select(
|
||||
col(Approval.task_id),
|
||||
func.count(col(Approval.id)).label("total"),
|
||||
func.sum(case((col(Approval.status) == "pending", 1), else_=0)).label("pending"),
|
||||
func.sum(
|
||||
case((col(Approval.status) == "pending", 1), else_=0),
|
||||
).label("pending"),
|
||||
)
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.task_id).is_not(None))
|
||||
.group_by(col(Approval.task_id))
|
||||
)
|
||||
.group_by(col(Approval.task_id)),
|
||||
),
|
||||
)
|
||||
for task_id, total, pending in rows:
|
||||
if task_id is None:
|
||||
|
||||
@@ -1,26 +1,33 @@
|
||||
"""Policy helpers for lead-agent approval and planning decisions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Mapping
|
||||
|
||||
CONFIDENCE_THRESHOLD = 80
|
||||
MIN_PLANNING_SIGNALS = 2
|
||||
|
||||
|
||||
def compute_confidence(rubric_scores: Mapping[str, int]) -> int:
|
||||
"""Compute aggregate confidence from rubric score components."""
|
||||
return int(sum(rubric_scores.values()))
|
||||
|
||||
|
||||
def approval_required(*, confidence: int, is_external: bool, is_risky: bool) -> bool:
|
||||
"""Return whether an action must go through explicit approval."""
|
||||
return is_external or is_risky or confidence < CONFIDENCE_THRESHOLD
|
||||
|
||||
|
||||
def infer_planning(signals: Mapping[str, bool]) -> bool:
|
||||
"""Infer planning intent from boolean heuristic signals."""
|
||||
# Require at least two planning signals to avoid spam on general boards.
|
||||
truthy = [key for key, value in signals.items() if value]
|
||||
return len(truthy) >= 2
|
||||
return len(truthy) >= MIN_PLANNING_SIGNALS
|
||||
|
||||
|
||||
def task_fingerprint(title: str, description: str | None, board_id: str) -> str:
|
||||
"""Build a stable hash key for deduplicating similar board tasks."""
|
||||
normalized_title = title.strip().lower()
|
||||
normalized_desc = (description or "").strip().lower()
|
||||
seed = f"{board_id}::{normalized_title}::{normalized_desc}"
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
"""Helpers for extracting and matching `@mention` tokens in text."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.agents import Agent
|
||||
if TYPE_CHECKING:
|
||||
from app.models.agents import Agent
|
||||
|
||||
# Mention tokens are single, space-free words (e.g. "@alex", "@lead").
|
||||
MENTION_PATTERN = re.compile(r"@([A-Za-z][\w-]{0,31})")
|
||||
|
||||
|
||||
def extract_mentions(message: str) -> set[str]:
|
||||
"""Extract normalized mention handles from a message body."""
|
||||
return {match.group(1).lower() for match in MENTION_PATTERN.finditer(message)}
|
||||
|
||||
|
||||
def matches_agent_mention(agent: Agent, mentions: set[str]) -> bool:
|
||||
"""Return whether a mention set targets the provided agent."""
|
||||
if not mentions:
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Organization membership and board-access service helpers."""
|
||||
# ruff: noqa: D101, D103
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.time import utcnow
|
||||
from app.db import crud
|
||||
@@ -19,7 +19,17 @@ from app.models.organization_invites import OrganizationInvite
|
||||
from app.models.organization_members import OrganizationMember
|
||||
from app.models.organizations import Organization
|
||||
from app.models.users import User
|
||||
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.schemas.organizations import (
|
||||
OrganizationBoardAccessSpec,
|
||||
OrganizationMemberAccessUpdate,
|
||||
)
|
||||
|
||||
DEFAULT_ORG_NAME = "Personal"
|
||||
ADMIN_ROLES = {"owner", "admin"}
|
||||
@@ -63,7 +73,9 @@ async def get_member(
|
||||
).first(session)
|
||||
|
||||
|
||||
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
|
||||
async def get_first_membership(
|
||||
session: AsyncSession, user_id: UUID,
|
||||
) -> OrganizationMember | None:
|
||||
return (
|
||||
await OrganizationMember.objects.filter_by(user_id=user_id)
|
||||
.order_by(col(OrganizationMember.created_at).asc())
|
||||
@@ -79,7 +91,9 @@ async def set_active_organization(
|
||||
) -> OrganizationMember:
|
||||
member = await get_member(session, user_id=user.id, organization_id=organization_id)
|
||||
if member is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
||||
)
|
||||
if user.active_organization_id != organization_id:
|
||||
user.active_organization_id = organization_id
|
||||
session.add(user)
|
||||
@@ -154,9 +168,10 @@ async def accept_invite(
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationInviteBoardAccess).where(
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
||||
)
|
||||
)
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id)
|
||||
== invite.id,
|
||||
),
|
||||
),
|
||||
)
|
||||
for row in access_rows:
|
||||
session.add(
|
||||
@@ -167,7 +182,7 @@ async def accept_invite(
|
||||
can_write=row.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
invite.accepted_by_user_id = user.id
|
||||
@@ -182,7 +197,9 @@ async def accept_invite(
|
||||
return member
|
||||
|
||||
|
||||
async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember:
|
||||
async def ensure_member_for_user(
|
||||
session: AsyncSession, user: User,
|
||||
) -> OrganizationMember:
|
||||
existing = await get_active_membership(session, user)
|
||||
if existing is not None:
|
||||
return existing
|
||||
@@ -196,7 +213,9 @@ async def ensure_member_for_user(session: AsyncSession, user: User) -> Organizat
|
||||
now = utcnow()
|
||||
member_count = (
|
||||
await session.exec(
|
||||
select(func.count()).where(col(OrganizationMember.organization_id) == org.id)
|
||||
select(func.count()).where(
|
||||
col(OrganizationMember.organization_id) == org.id,
|
||||
),
|
||||
)
|
||||
).one()
|
||||
is_first = int(member_count or 0) == 0
|
||||
@@ -257,30 +276,40 @@ async def require_board_access(
|
||||
board: Board,
|
||||
write: bool,
|
||||
) -> OrganizationMember:
|
||||
member = await get_member(session, user_id=user.id, organization_id=board.organization_id)
|
||||
member = await get_member(
|
||||
session, user_id=user.id, organization_id=board.organization_id,
|
||||
)
|
||||
if member is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
||||
)
|
||||
if not await has_board_access(session, member=member, board=board, write=write):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied",
|
||||
)
|
||||
return member
|
||||
|
||||
|
||||
def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]:
|
||||
def board_access_filter(
|
||||
member: OrganizationMember, *, write: bool,
|
||||
) -> ColumnElement[bool]:
|
||||
if write and member_all_boards_write(member):
|
||||
return col(Board.organization_id) == member.organization_id
|
||||
if not write and member_all_boards_read(member):
|
||||
return col(Board.organization_id) == member.organization_id
|
||||
access_stmt = select(OrganizationBoardAccess.board_id).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
)
|
||||
if write:
|
||||
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
|
||||
access_stmt = access_stmt.where(
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
else:
|
||||
access_stmt = access_stmt.where(
|
||||
or_(
|
||||
col(OrganizationBoardAccess.can_read).is_(True),
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
),
|
||||
)
|
||||
return col(Board.id).in_(access_stmt)
|
||||
|
||||
@@ -295,21 +324,25 @@ async def list_accessible_board_ids(
|
||||
not write and member_all_boards_read(member)
|
||||
):
|
||||
ids = await session.exec(
|
||||
select(Board.id).where(col(Board.organization_id) == member.organization_id)
|
||||
select(Board.id).where(
|
||||
col(Board.organization_id) == member.organization_id,
|
||||
),
|
||||
)
|
||||
return list(ids)
|
||||
|
||||
access_stmt = select(OrganizationBoardAccess.board_id).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
)
|
||||
if write:
|
||||
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
|
||||
access_stmt = access_stmt.where(
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
else:
|
||||
access_stmt = access_stmt.where(
|
||||
or_(
|
||||
col(OrganizationBoardAccess.can_read).is_(True),
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
),
|
||||
)
|
||||
board_ids = await session.exec(access_stmt)
|
||||
return list(board_ids)
|
||||
@@ -337,18 +370,17 @@ async def apply_member_access_update(
|
||||
if update.all_boards_read or update.all_boards_write:
|
||||
return
|
||||
|
||||
rows: list[OrganizationBoardAccess] = []
|
||||
for entry in update.board_access:
|
||||
rows.append(
|
||||
OrganizationBoardAccess(
|
||||
organization_member_id=member.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
rows = [
|
||||
OrganizationBoardAccess(
|
||||
organization_member_id=member.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
for entry in update.board_access
|
||||
]
|
||||
session.add_all(rows)
|
||||
|
||||
|
||||
@@ -367,18 +399,17 @@ async def apply_invite_board_access(
|
||||
if invite.all_boards_read or invite.all_boards_write:
|
||||
return
|
||||
now = utcnow()
|
||||
rows: list[OrganizationInviteBoardAccess] = []
|
||||
for entry in entries:
|
||||
rows.append(
|
||||
OrganizationInviteBoardAccess(
|
||||
organization_invite_id=invite.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
rows = [
|
||||
OrganizationInviteBoardAccess(
|
||||
organization_invite_id=invite.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
for entry in entries
|
||||
]
|
||||
session.add_all(rows)
|
||||
|
||||
|
||||
@@ -423,9 +454,9 @@ async def apply_invite_to_member(
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationInviteBoardAccess).where(
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
||||
)
|
||||
)
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
|
||||
),
|
||||
),
|
||||
)
|
||||
for row in access_rows:
|
||||
existing = (
|
||||
@@ -433,7 +464,7 @@ async def apply_invite_to_member(
|
||||
select(OrganizationBoardAccess).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
col(OrganizationBoardAccess.board_id) == row.board_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
).first()
|
||||
can_write = bool(row.can_write)
|
||||
@@ -447,7 +478,7 @@ async def apply_invite_to_member(
|
||||
can_write=can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
existing.can_read = bool(existing.can_read or can_read)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Service helpers for querying and caching souls.directory content."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
@@ -11,33 +13,41 @@ SOULS_DIRECTORY_BASE_URL: Final[str] = "https://souls.directory"
|
||||
SOULS_DIRECTORY_SITEMAP_URL: Final[str] = f"{SOULS_DIRECTORY_BASE_URL}/sitemap.xml"
|
||||
|
||||
_SITEMAP_TTL_SECONDS: Final[int] = 60 * 60
|
||||
_SOUL_URL_MIN_PARTS: Final[int] = 6
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SoulRef:
|
||||
"""Handle/slug reference pair for a soul entry."""
|
||||
|
||||
handle: str
|
||||
slug: str
|
||||
|
||||
@property
|
||||
def page_url(self) -> str:
|
||||
"""Return the canonical page URL for this soul."""
|
||||
return f"{SOULS_DIRECTORY_BASE_URL}/souls/{self.handle}/{self.slug}"
|
||||
|
||||
@property
|
||||
def raw_md_url(self) -> str:
|
||||
"""Return the raw markdown URL for this soul."""
|
||||
return f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{self.handle}/{self.slug}.md"
|
||||
|
||||
|
||||
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
|
||||
"""Parse sitemap XML and extract valid souls.directory handle/slug refs."""
|
||||
try:
|
||||
root = ET.fromstring(sitemap_xml)
|
||||
# Souls sitemap is fetched from a known trusted host in this service flow.
|
||||
root = ET.fromstring(sitemap_xml) # noqa: S314
|
||||
except ET.ParseError:
|
||||
return []
|
||||
|
||||
# Handle both namespaced and non-namespaced sitemap XML.
|
||||
urls: list[str] = []
|
||||
for loc in root.iter():
|
||||
if loc.tag.endswith("loc") and loc.text:
|
||||
urls.append(loc.text.strip())
|
||||
urls = [
|
||||
loc.text.strip()
|
||||
for loc in root.iter()
|
||||
if loc.tag.endswith("loc") and loc.text
|
||||
]
|
||||
|
||||
refs: list[SoulRef] = []
|
||||
for url in urls:
|
||||
@@ -45,7 +55,7 @@ def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
|
||||
continue
|
||||
# Expected: https://souls.directory/souls/{handle}/{slug}
|
||||
parts = url.split("/")
|
||||
if len(parts) < 6:
|
||||
if len(parts) < _SOUL_URL_MIN_PARTS:
|
||||
continue
|
||||
handle = parts[4].strip()
|
||||
slug = parts[5].strip()
|
||||
@@ -61,7 +71,11 @@ _sitemap_cache: dict[str, object] = {
|
||||
}
|
||||
|
||||
|
||||
async def list_souls_directory_refs(*, client: httpx.AsyncClient | None = None) -> list[SoulRef]:
|
||||
async def list_souls_directory_refs(
|
||||
*,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> list[SoulRef]:
|
||||
"""Return cached sitemap-derived soul refs, refreshing when TTL expires."""
|
||||
now = time.time()
|
||||
loaded_raw = _sitemap_cache.get("loaded_at")
|
||||
loaded_at = loaded_raw if isinstance(loaded_raw, (int, float)) else 0.0
|
||||
@@ -93,11 +107,15 @@ async def fetch_soul_markdown(
|
||||
slug: str,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> str:
|
||||
"""Fetch raw markdown content for a specific handle/slug pair."""
|
||||
normalized_handle = handle.strip().strip("/")
|
||||
normalized_slug = slug.strip().strip("/")
|
||||
if normalized_slug.endswith(".md"):
|
||||
normalized_slug = normalized_slug[: -len(".md")]
|
||||
url = f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{normalized_handle}/{normalized_slug}.md"
|
||||
url = (
|
||||
f"{SOULS_DIRECTORY_BASE_URL}/api/souls/"
|
||||
f"{normalized_handle}/{normalized_slug}.md"
|
||||
)
|
||||
|
||||
owns_client = client is None
|
||||
if client is None:
|
||||
@@ -115,6 +133,7 @@ async def fetch_soul_markdown(
|
||||
|
||||
|
||||
def search_souls(refs: list[SoulRef], *, query: str, limit: int = 20) -> list[SoulRef]:
|
||||
"""Search refs by case-insensitive handle/slug substring with a hard limit."""
|
||||
q = query.strip().lower()
|
||||
if not q:
|
||||
return refs[: max(0, min(limit, len(refs)))]
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Background worker tasks and queue processing utilities."""
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""RQ queue and Redis connection helpers for background workers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from redis import Redis
|
||||
@@ -7,8 +9,10 @@ from app.core.config import settings
|
||||
|
||||
|
||||
def get_redis() -> Redis:
|
||||
"""Create a Redis client from configured settings."""
|
||||
return Redis.from_url(settings.redis_url)
|
||||
|
||||
|
||||
def get_queue(name: str) -> Queue:
|
||||
"""Return an RQ queue bound to the configured Redis connection."""
|
||||
return Queue(name, connection=get_redis())
|
||||
|
||||
Reference in New Issue
Block a user