refactor: update module docstrings for clarity and consistency

This commit is contained in:
Abhimanyu Saharan
2026-02-09 15:49:50 +05:30
parent 78bb08d4a3
commit 7ca1899d9f
99 changed files with 2345 additions and 855 deletions

View File

@@ -0,0 +1 @@
"""OpenClaw Mission Control backend application package."""

View File

@@ -0,0 +1 @@
"""API router modules for the OpenClaw Mission Control backend."""

View File

@@ -1,13 +1,14 @@
"""Agent-scoped API routes for board operations and gateway coordination."""
from __future__ import annotations
import re
from collections.abc import Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import SQLModel, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api import agents as agents_api
from app.api import approvals as approvals_api
@@ -27,11 +28,7 @@ from app.integrations.openclaw_gateway import (
openclaw_call,
send_message,
)
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approvals import Approval
from app.models.board_memory import BoardMemory
from app.models.board_onboarding import BoardOnboardingSession
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.task_dependencies import TaskDependency
@@ -58,7 +55,13 @@ from app.schemas.gateway_coordination import (
GatewayMainAskUserResponse,
)
from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
from app.schemas.tasks import (
TaskCommentCreate,
TaskCommentRead,
TaskCreate,
TaskRead,
TaskUpdate,
)
from app.services.activity_log import record_activity
from app.services.board_leads import ensure_board_lead_agent
from app.services.task_dependencies import (
@@ -67,11 +70,27 @@ from app.services.task_dependencies import (
validate_dependency_update,
)
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.activity_events import ActivityEvent
from app.models.approvals import Approval
from app.models.board_memory import BoardMemory
from app.models.board_onboarding import BoardOnboardingSession
router = APIRouter(prefix="/agent", tags=["agent"])
_AGENT_SESSION_PREFIX = "agent:"
_SESSION_KEY_PARTS_MIN = 2
_LEAD_SESSION_KEY_MISSING = "Lead agent has no session key"
SESSION_DEP = Depends(get_session)
AGENT_CTX_DEP = Depends(get_agent_auth_context)
BOARD_DEP = Depends(get_board_or_404)
TASK_DEP = Depends(get_task_or_404)
BOARD_ID_QUERY = Query(default=None)
TASK_STATUS_QUERY = Query(default=None, alias="status")
IS_CHAT_QUERY = Query(default=None)
APPROVAL_STATUS_QUERY = Query(default=None, alias="status")
def _gateway_agent_id(agent: Agent) -> str:
@@ -87,6 +106,8 @@ def _gateway_agent_id(agent: Agent) -> str:
class SoulUpdateRequest(SQLModel):
"""Payload for updating an agent SOUL document."""
content: str
source_url: str | None = None
reason: str | None = None
@@ -124,9 +145,12 @@ async def _require_gateway_main(
session_key = (agent.openclaw_session_id or "").strip()
if not session_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Agent missing session key"
status_code=status.HTTP_403_FORBIDDEN,
detail="Agent missing session key",
)
gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(session)
gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(
session,
)
if gateway is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -148,7 +172,9 @@ async def _require_gateway_board(
) -> Board:
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
)
if board.gateway_id != gateway.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return board
@@ -156,9 +182,10 @@ async def _require_gateway_board(
@router.get("/boards", response_model=DefaultLimitOffsetPage[BoardRead])
async def list_boards(
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[BoardRead]:
"""List boards visible to the authenticated agent."""
statement = select(Board)
if agent_ctx.agent.board_id:
statement = statement.where(col(Board.id) == agent_ctx.agent.board_id)
@@ -168,19 +195,21 @@ async def list_boards(
@router.get("/boards/{board_id}", response_model=BoardRead)
def get_board(
board: Board = Depends(get_board_or_404),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> Board:
"""Return a board if the authenticated agent can access it."""
_guard_board_access(agent_ctx, board)
return board
@router.get("/agents", response_model=DefaultLimitOffsetPage[AgentRead])
async def list_agents(
board_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board_id: UUID | None = BOARD_ID_QUERY,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[AgentRead]:
"""List agents, optionally filtered to a board."""
statement = select(Agent)
if agent_ctx.agent.board_id:
if board_id and board_id != agent_ctx.agent.board_id:
@@ -188,13 +217,19 @@ async def list_agents(
statement = statement.where(Agent.board_id == agent_ctx.agent.board_id)
elif board_id:
statement = statement.where(Agent.board_id == board_id)
main_session_keys = await agents_api._get_gateway_main_session_keys(session)
get_gateway_main_session_keys = (
agents_api._get_gateway_main_session_keys # noqa: SLF001
)
to_agent_read = agents_api._to_agent_read # noqa: SLF001
with_computed_status = agents_api._with_computed_status # noqa: SLF001
main_session_keys = await get_gateway_main_session_keys(session)
statement = statement.order_by(col(Agent.created_at).desc())
def _transform(items: Sequence[Any]) -> Sequence[Any]:
agents = cast(Sequence[Agent], items)
return [
agents_api._to_agent_read(agents_api._with_computed_status(agent), main_session_keys)
to_agent_read(with_computed_status(agent), main_session_keys)
for agent in agents
]
@@ -202,14 +237,15 @@ async def list_agents(
@router.get("/boards/{board_id}/tasks", response_model=DefaultLimitOffsetPage[TaskRead])
async def list_tasks(
status_filter: str | None = Query(default=None, alias="status"),
async def list_tasks( # noqa: PLR0913
status_filter: str | None = TASK_STATUS_QUERY,
assigned_agent_id: UUID | None = None,
unassigned: bool | None = None,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[TaskRead]:
"""List tasks on a board with optional status and assignment filters."""
_guard_board_access(agent_ctx, board)
return await tasks_api.list_tasks(
status_filter=status_filter,
@@ -224,10 +260,11 @@ async def list_tasks(
@router.post("/boards/{board_id}/tasks", response_model=TaskRead)
async def create_task(
payload: TaskCreate,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> TaskRead:
"""Create a task on the board as the lead agent."""
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -250,7 +287,9 @@ async def create_task(
board_id=board.id,
dependency_ids=normalized_deps,
)
blocked_by = blocked_by_dependency_ids(dependency_ids=normalized_deps, status_by_id=dep_status)
blocked_by = blocked_by_dependency_ids(
dependency_ids=normalized_deps, status_by_id=dep_status,
)
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
raise HTTPException(
@@ -280,7 +319,7 @@ async def create_task(
board_id=board.id,
task_id=task.id,
depends_on_task_id=dep_id,
)
),
)
await session.commit()
await session.refresh(task)
@@ -293,9 +332,14 @@ async def create_task(
)
await session.commit()
if task.assigned_agent_id:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent:
await tasks_api._notify_agent_on_task_assign(
notify_agent_on_task_assign = (
tasks_api._notify_agent_on_task_assign # noqa: SLF001
)
await notify_agent_on_task_assign(
session=session,
board=board,
task=task,
@@ -306,18 +350,23 @@ async def create_task(
"depends_on_task_ids": normalized_deps,
"blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by),
}
},
)
@router.patch("/boards/{board_id}/tasks/{task_id}", response_model=TaskRead)
async def update_task(
payload: TaskUpdate,
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> TaskRead:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
"""Update a task after board-level access checks."""
if (
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.update_task(
payload=payload,
@@ -332,11 +381,16 @@ async def update_task(
response_model=DefaultLimitOffsetPage[TaskCommentRead],
)
async def list_task_comments(
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[TaskCommentRead]:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
"""List comments for a task visible to the authenticated agent."""
if (
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.list_task_comments(
task=task,
@@ -344,14 +398,21 @@ async def list_task_comments(
)
@router.post("/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead)
@router.post(
"/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead,
)
async def create_task_comment(
payload: TaskCommentCreate,
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> ActivityEvent:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
"""Create a task comment on behalf of the authenticated agent."""
if (
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.create_task_comment(
payload=payload,
@@ -361,13 +422,16 @@ async def create_task_comment(
)
@router.get("/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead])
@router.get(
"/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead],
)
async def list_board_memory(
is_chat: bool | None = Query(default=None),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
is_chat: bool | None = IS_CHAT_QUERY,
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[BoardMemoryRead]:
"""List board memory entries with optional chat filtering."""
_guard_board_access(agent_ctx, board)
return await board_memory_api.list_board_memory(
is_chat=is_chat,
@@ -380,10 +444,11 @@ async def list_board_memory(
@router.post("/boards/{board_id}/memory", response_model=BoardMemoryRead)
async def create_board_memory(
payload: BoardMemoryCreate,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> BoardMemory:
"""Create a board memory entry."""
_guard_board_access(agent_ctx, board)
return await board_memory_api.create_board_memory(
payload=payload,
@@ -398,11 +463,12 @@ async def create_board_memory(
response_model=DefaultLimitOffsetPage[ApprovalRead],
)
async def list_approvals(
status_filter: ApprovalStatus | None = Query(default=None, alias="status"),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
status_filter: ApprovalStatus | None = APPROVAL_STATUS_QUERY,
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[ApprovalRead]:
"""List approvals for a board."""
_guard_board_access(agent_ctx, board)
return await approvals_api.list_approvals(
status_filter=status_filter,
@@ -415,10 +481,11 @@ async def list_approvals(
@router.post("/boards/{board_id}/approvals", response_model=ApprovalRead)
async def create_approval(
payload: ApprovalCreate,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> Approval:
"""Create a board approval request."""
_guard_board_access(agent_ctx, board)
return await approvals_api.create_approval(
payload=payload,
@@ -431,10 +498,11 @@ async def create_approval(
@router.post("/boards/{board_id}/onboarding", response_model=BoardOnboardingRead)
async def update_onboarding(
payload: BoardOnboardingAgentUpdate,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> BoardOnboardingSession:
"""Apply onboarding updates for a board."""
_guard_board_access(agent_ctx, board)
return await onboarding_api.agent_onboarding_update(
payload=payload,
@@ -447,14 +515,17 @@ async def update_onboarding(
@router.post("/agents", response_model=AgentRead)
async def create_agent(
payload: AgentCreate,
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> AgentRead:
"""Create an agent on the caller's board."""
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not agent_ctx.agent.board_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
payload = AgentCreate(**{**payload.model_dump(), "board_id": agent_ctx.agent.board_id})
payload = AgentCreate(
**{**payload.model_dump(), "board_id": agent_ctx.agent.board_id},
)
return await agents_api.create_agent(
payload=payload,
session=session,
@@ -466,10 +537,11 @@ async def create_agent(
async def nudge_agent(
payload: AgentNudge,
agent_id: str,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> OkResponse:
"""Send a direct nudge message to a board agent."""
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -484,7 +556,9 @@ async def nudge_agent(
message = payload.message
config = await _gateway_config(session, board)
try:
await ensure_session(target.openclaw_session_id, config=config, label=target.name)
await ensure_session(
target.openclaw_session_id, config=config, label=target.name,
)
await send_message(
message,
session_key=target.openclaw_session_id,
@@ -499,7 +573,9 @@ async def nudge_agent(
agent_id=agent_ctx.agent.id,
)
await session.commit()
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
record_activity(
session,
event_type="agent.nudge.sent",
@@ -513,9 +589,10 @@ async def nudge_agent(
@router.post("/heartbeat", response_model=AgentRead)
async def agent_heartbeat(
payload: AgentHeartbeatCreate,
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> AgentRead:
"""Record heartbeat status for the authenticated agent."""
# Heartbeats must apply to the authenticated agent; agent names are not unique.
return await agents_api.heartbeat_agent(
agent_id=str(agent_ctx.agent.id),
@@ -528,10 +605,11 @@ async def agent_heartbeat(
@router.get("/boards/{board_id}/agents/{agent_id}/soul", response_model=str)
async def get_agent_soul(
agent_id: str,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> str:
"""Fetch the target agent's SOUL.md content from the gateway."""
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -547,7 +625,9 @@ async def get_agent_soul(
config=config,
)
except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
if isinstance(payload, str):
return payload
if isinstance(payload, dict):
@@ -559,17 +639,20 @@ async def get_agent_soul(
nested = file_obj.get("content")
if isinstance(nested, str):
return nested
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response")
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response",
)
@router.put("/boards/{board_id}/agents/{agent_id}/soul", response_model=OkResponse)
async def update_agent_soul(
agent_id: str,
payload: SoulUpdateRequest,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> OkResponse:
"""Update an agent's SOUL.md content in DB and gateway."""
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -597,7 +680,9 @@ async def update_agent_soul(
config=config,
)
except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
reason = (payload.reason or "").strip()
source_url = (payload.source_url or "").strip()
note = f"SOUL.md updated for {target.name}."
@@ -621,10 +706,11 @@ async def update_agent_soul(
)
async def ask_user_via_gateway_main(
payload: GatewayMainAskUserRequest,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> GatewayMainAskUserResponse:
"""Route a lead's ask-user request through the gateway main agent."""
import json
_guard_board_access(agent_ctx, board)
@@ -653,7 +739,9 @@ async def ask_user_via_gateway_main(
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
preferred_channel = (payload.preferred_channel or "").strip()
channel_line = f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
channel_line = (
f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
)
tags = payload.reply_tags or ["gateway_main", "user_reply"]
tags_json = json.dumps(tags)
@@ -668,9 +756,12 @@ async def ask_user_via_gateway_main(
f"{correlation_line}"
f"{channel_line}\n"
f"{payload.content.strip()}\n\n"
"Please reach the user via your configured OpenClaw channel(s) (Slack/SMS/etc).\n"
"If you cannot reach them there, post the question in Mission Control board chat as a fallback.\n\n"
"When you receive the answer, reply in Mission Control by writing a NON-chat memory item on this board:\n"
"Please reach the user via your configured OpenClaw channel(s) "
"(Slack/SMS/etc).\n"
"If you cannot reach them there, post the question in Mission Control "
"board chat as a fallback.\n\n"
"When you receive the answer, reply in Mission Control by writing a "
"NON-chat memory item on this board:\n"
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
f'Body: {{"content":"<answer>","tags":{tags_json},"source":"{reply_source}"}}\n'
"Do NOT reply in OpenClaw chat."
@@ -678,7 +769,9 @@ async def ask_user_via_gateway_main(
try:
await ensure_session(main_session_key, config=config, label="Main Agent")
await send_message(message, session_key=main_session_key, config=config, deliver=True)
await send_message(
message, session_key=main_session_key, config=config, deliver=True,
)
except OpenClawGatewayError as exc:
record_activity(
session,
@@ -687,7 +780,9 @@ async def ask_user_via_gateway_main(
agent_id=agent_ctx.agent.id,
)
await session.commit()
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
record_activity(
session,
@@ -696,7 +791,9 @@ async def ask_user_via_gateway_main(
agent_id=agent_ctx.agent.id,
)
main_agent = await Agent.objects.filter_by(openclaw_session_id=main_session_key).first(session)
main_agent = await Agent.objects.filter_by(
openclaw_session_id=main_session_key,
).first(session)
await session.commit()
@@ -714,9 +811,10 @@ async def ask_user_via_gateway_main(
async def message_gateway_board_lead(
board_id: UUID,
payload: GatewayLeadMessageRequest,
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> GatewayLeadMessageResponse:
"""Send a gateway-main message to a single board lead agent."""
import json
gateway, config = await _require_gateway_main(session, agent_ctx.agent)
@@ -736,7 +834,11 @@ async def message_gateway_board_lead(
)
base_url = settings.base_url or "http://localhost:8000"
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
header = (
"GATEWAY MAIN QUESTION"
if payload.kind == "question"
else "GATEWAY MAIN HANDOFF"
)
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
@@ -767,7 +869,9 @@ async def message_gateway_board_lead(
agent_id=agent_ctx.agent.id,
)
await session.commit()
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
record_activity(
session,
@@ -791,9 +895,10 @@ async def message_gateway_board_lead(
)
async def broadcast_gateway_lead_message(
payload: GatewayLeadBroadcastRequest,
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> GatewayLeadBroadcastResponse:
"""Broadcast a gateway-main message to multiple board leads."""
import json
gateway, config = await _require_gateway_main(session, agent_ctx.agent)
@@ -808,7 +913,11 @@ async def broadcast_gateway_lead_message(
boards = list(await session.exec(statement))
base_url = settings.base_url or "http://localhost:8000"
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
header = (
"GATEWAY MAIN QUESTION"
if payload.kind == "question"
else "GATEWAY MAIN HANDOFF"
)
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
@@ -819,7 +928,7 @@ async def broadcast_gateway_lead_message(
sent = 0
failed = 0
for board in boards:
async def _send_to_board(board: Board) -> GatewayLeadBroadcastBoardResult:
try:
lead, _lead_created = await ensure_board_lead_agent(
session,
@@ -837,30 +946,34 @@ async def broadcast_gateway_lead_message(
f"From agent: {agent_ctx.agent.name}\n"
f"{correlation_line}\n"
f"{payload.content.strip()}\n\n"
"Reply to the gateway main by writing a NON-chat memory item on this board:\n"
"Reply to the gateway main by writing a NON-chat memory item "
"on this board:\n"
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
f'Body: {{"content":"...","tags":{tags_json},"source":"{reply_source}"}}\n'
f'Body: {{"content":"...","tags":{tags_json},'
f'"source":"{reply_source}"}}\n'
"Do NOT reply in OpenClaw chat."
)
await ensure_session(lead_session_key, config=config, label=lead.name)
await send_message(message, session_key=lead_session_key, config=config)
results.append(
GatewayLeadBroadcastBoardResult(
board_id=board.id,
lead_agent_id=lead.id,
lead_agent_name=lead.name,
ok=True,
)
return GatewayLeadBroadcastBoardResult(
board_id=board.id,
lead_agent_id=lead.id,
lead_agent_name=lead.name,
ok=True,
)
sent += 1
except (HTTPException, OpenClawGatewayError, ValueError) as exc:
results.append(
GatewayLeadBroadcastBoardResult(
board_id=board.id,
ok=False,
error=str(exc),
)
return GatewayLeadBroadcastBoardResult(
board_id=board.id,
ok=False,
error=str(exc),
)
for board in boards:
board_result = await _send_to_board(board)
results.append(board_result)
if board_result.ok:
sent += 1
else:
failed += 1
record_activity(

View File

@@ -1,3 +1,5 @@
"""Agent lifecycle, listing, heartbeat, and deletion API endpoints."""
from __future__ import annotations
import asyncio
@@ -5,14 +7,12 @@ import json
import re
from collections.abc import AsyncIterator, Sequence
from datetime import datetime, timedelta, timezone
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, or_
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin
@@ -23,14 +23,17 @@ from app.db import crud
from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.organizations import Organization
from app.models.tasks import Task
from app.models.users import User
from app.schemas.agents import (
AgentCreate,
AgentHeartbeat,
@@ -56,10 +59,23 @@ from app.services.organizations import (
require_board_access,
)
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.users import User
router = APIRouter(prefix="/agents", tags=["agents"])
OFFLINE_AFTER = timedelta(minutes=10)
AGENT_SESSION_PREFIX = "agent"
BOARD_ID_QUERY = Query(default=None)
GATEWAY_ID_QUERY = Query(default=None)
SINCE_QUERY = Query(default=None)
SESSION_DEP = Depends(get_session)
ORG_ADMIN_DEP = Depends(require_org_admin)
ACTOR_DEP = Depends(require_admin_or_agent)
AUTH_DEP = Depends(get_auth_context)
def _parse_since(value: str | None) -> datetime | None:
@@ -111,14 +127,16 @@ async def _require_board(
)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
)
if user is not None:
await require_board_access(session, user=user, board=board, write=write)
return board
async def _require_gateway(
session: AsyncSession, board: Board
session: AsyncSession, board: Board,
) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id:
raise HTTPException(
@@ -169,16 +187,20 @@ async def _get_gateway_main_session_keys(session: AsyncSession) -> set[str]:
def _is_gateway_main(agent: Agent, main_session_keys: set[str]) -> bool:
return bool(agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys)
return bool(
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys,
)
def _to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
model = AgentRead.model_validate(agent, from_attributes=True)
return model.model_copy(update={"is_gateway_main": _is_gateway_main(agent, main_session_keys)})
return model.model_copy(
update={"is_gateway_main": _is_gateway_main(agent, main_session_keys)},
)
async def _find_gateway_for_main_session(
session: AsyncSession, session_key: str | None
session: AsyncSession, session_key: str | None,
) -> Gateway | None:
if not session_key:
return None
@@ -210,7 +232,9 @@ def _with_computed_status(agent: Agent) -> Agent:
def _serialize_agent(agent: Agent, main_session_keys: set[str]) -> dict[str, object]:
return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump(mode="json")
return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump(
mode="json",
)
async def _fetch_agent_events(
@@ -225,18 +249,22 @@ async def _fetch_agent_events(
or_(
col(Agent.updated_at) >= since,
col(Agent.last_seen_at) >= since,
)
),
).order_by(asc(col(Agent.updated_at)))
return list(await session.exec(statement))
async def _require_user_context(session: AsyncSession, user: User | None) -> OrganizationContext:
async def _require_user_context(
session: AsyncSession, user: User | None,
) -> OrganizationContext:
if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
member = await get_active_membership(session, user)
if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
organization = await Organization.objects.by_id(member.organization_id).first(session)
organization = await Organization.objects.by_id(member.organization_id).first(
session,
)
if organization is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return OrganizationContext(organization=organization, member=member)
@@ -252,7 +280,9 @@ async def _require_agent_access(
if agent.board_id is None:
if not is_org_admin(ctx.member):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
gateway = await _find_gateway_for_main_session(session, agent.openclaw_session_id)
gateway = await _find_gateway_for_main_session(
session, agent.openclaw_session_id,
)
if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return
@@ -274,7 +304,7 @@ def _record_heartbeat(session: AsyncSession, agent: Agent) -> None:
def _record_instruction_failure(
session: AsyncSession, agent: Agent, error: str, action: str
session: AsyncSession, agent: Agent, error: str, action: str,
) -> None:
action_label = action.replace("_", " ").capitalize()
record_activity(
@@ -286,7 +316,7 @@ def _record_instruction_failure(
async def _send_wakeup_message(
agent: Agent, config: GatewayClientConfig, verb: str = "provisioned"
agent: Agent, config: GatewayClientConfig, verb: str = "provisioned",
) -> None:
session_key = agent.openclaw_session_id or _build_session_key(agent.name)
await ensure_session(session_key, config=config, label=agent.name)
@@ -300,11 +330,12 @@ async def _send_wakeup_message(
@router.get("", response_model=DefaultLimitOffsetPage[AgentRead])
async def list_agents(
board_id: UUID | None = Query(default=None),
gateway_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
board_id: UUID | None = BOARD_ID_QUERY,
gateway_id: UUID | None = GATEWAY_ID_QUERY,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[AgentRead]:
"""List agents visible to the active organization admin."""
main_session_keys = await _get_gateway_main_session_keys(session)
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
if board_id is not None and board_id not in set(board_ids):
@@ -315,9 +346,11 @@ async def list_agents(
base_filter: ColumnElement[bool] = col(Agent.board_id).in_(board_ids)
if is_org_admin(ctx.member):
gateway_keys = select(Gateway.main_session_key).where(
col(Gateway.organization_id) == ctx.organization.id
col(Gateway.organization_id) == ctx.organization.id,
)
base_filter = or_(
base_filter, col(Agent.openclaw_session_id).in_(gateway_keys),
)
base_filter = or_(base_filter, col(Agent.openclaw_session_id).in_(gateway_keys))
statement = select(Agent).where(base_filter)
if board_id is not None:
statement = statement.where(col(Agent.board_id) == board_id)
@@ -326,13 +359,16 @@ async def list_agents(
if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where(
col(Board.gateway_id) == gateway_id
col(Board.gateway_id) == gateway_id,
)
statement = statement.order_by(col(Agent.created_at).desc())
def _transform(items: Sequence[Any]) -> Sequence[Any]:
agents = cast(Sequence[Agent], items)
return [_to_agent_read(_with_computed_status(agent), main_session_keys) for agent in agents]
return [
_to_agent_read(_with_computed_status(agent), main_session_keys)
for agent in agents
]
return await paginate(session, statement, transformer=_transform)
@@ -340,11 +376,12 @@ async def list_agents(
@router.get("/stream")
async def stream_agents(
request: Request,
board_id: UUID | None = Query(default=None),
since: str | None = Query(default=None),
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
board_id: UUID | None = BOARD_ID_QUERY,
since: str | None = SINCE_QUERY,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> EventSourceResponse:
"""Stream agent updates as SSE events."""
since_dt = _parse_since(since) or utcnow()
last_seen = since_dt
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
@@ -359,14 +396,20 @@ async def stream_agents(
break
async with async_session_maker() as stream_session:
if board_id is not None:
agents = await _fetch_agent_events(stream_session, board_id, last_seen)
agents = await _fetch_agent_events(
stream_session, board_id, last_seen,
)
elif allowed_ids:
agents = await _fetch_agent_events(stream_session, None, last_seen)
agents = [agent for agent in agents if agent.board_id in allowed_ids]
agents = [
agent for agent in agents if agent.board_id in allowed_ids
]
else:
agents = []
main_session_keys = (
await _get_gateway_main_session_keys(stream_session) if agents else set()
await _get_gateway_main_session_keys(stream_session)
if agents
else set()
)
for agent in agents:
updated_at = agent.updated_at or agent.last_seen_at or utcnow()
@@ -379,11 +422,12 @@ async def stream_agents(
@router.post("", response_model=AgentRead)
async def create_agent(
async def create_agent( # noqa: C901, PLR0912, PLR0915
payload: AgentCreate,
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP,
) -> AgentRead:
"""Create and provision an agent."""
if actor.actor_type == "user":
ctx = await _require_user_context(session, actor.user)
if not is_org_admin(ctx.member):
@@ -404,7 +448,9 @@ async def create_agent(
status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads can only create agents in their own board",
)
payload = AgentCreate(**{**payload.model_dump(), "board_id": actor.agent.board_id})
payload = AgentCreate(
**{**payload.model_dump(), "board_id": actor.agent.board_id},
)
board = await _require_board(
session,
@@ -420,7 +466,7 @@ async def create_agent(
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.name).ilike(requested_name))
.where(col(Agent.name).ilike(requested_name)),
)
).first()
if existing:
@@ -428,20 +474,23 @@ async def create_agent(
status_code=status.HTTP_409_CONFLICT,
detail="An agent with this name already exists on this board.",
)
# Prevent OpenClaw session/workspace collisions by enforcing uniqueness within
# the gateway workspace too (agents on other boards share the same gateway root).
# Prevent session/workspace collisions inside the gateway workspace.
# Agents on different boards can still share one gateway root.
existing_gateway = (
await session.exec(
select(Agent)
.join(Board, col(Agent.board_id) == col(Board.id))
.where(col(Board.gateway_id) == gateway.id)
.where(col(Agent.name).ilike(requested_name))
.where(col(Agent.name).ilike(requested_name)),
)
).first()
if existing_gateway:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="An agent with this name already exists in this gateway workspace.",
detail=(
"An agent with this name already exists in this gateway "
"workspace."
),
)
desired_session_key = _build_session_key(requested_name)
existing_session_key = (
@@ -449,13 +498,16 @@ async def create_agent(
select(Agent)
.join(Board, col(Agent.board_id) == col(Board.id))
.where(col(Board.gateway_id) == gateway.id)
.where(col(Agent.openclaw_session_id) == desired_session_key)
.where(col(Agent.openclaw_session_id) == desired_session_key),
)
).first()
if existing_session_key:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="This agent name would collide with an existing workspace session key. Pick a different name.",
detail=(
"This agent name would collide with an existing workspace "
"session key. Pick a different name."
),
)
agent = Agent.model_validate(data)
agent.status = "provisioning"
@@ -465,7 +517,9 @@ async def create_agent(
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
agent.provision_requested_at = utcnow()
agent.provision_action = "provision"
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
session_key, session_error = await _ensure_gateway_session(
agent.name, client_config,
)
agent.openclaw_session_id = session_key
session.add(agent)
await session.commit()
@@ -527,9 +581,10 @@ async def create_agent(
@router.get("/{agent_id}", response_model=AgentRead)
async def get_agent(
agent_id: str,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> AgentRead:
"""Get a single agent by id."""
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -539,14 +594,16 @@ async def get_agent(
@router.patch("/{agent_id}", response_model=AgentRead)
async def update_agent(
async def update_agent( # noqa: C901, PLR0912, PLR0913, PLR0915
agent_id: str,
payload: AgentUpdate,
*,
force: bool = False,
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> AgentRead:
"""Update agent metadata and optionally reprovision."""
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -564,12 +621,16 @@ async def update_agent(
new_board = await _require_board(session, updates["board_id"])
if new_board.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if not await has_board_access(session, member=ctx.member, board=new_board, write=True):
if not await has_board_access(
session, member=ctx.member, board=new_board, write=True,
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not updates and not force and make_main is None:
main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys)
main_gateway = await _find_gateway_for_main_session(session, agent.openclaw_session_id)
main_gateway = await _find_gateway_for_main_session(
session, agent.openclaw_session_id,
)
gateway_for_main: Gateway | None = None
if make_main is True:
board_source = updates.get("board_id") or agent.board_id
@@ -723,9 +784,10 @@ async def update_agent(
async def heartbeat_agent(
agent_id: str,
payload: AgentHeartbeat,
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP,
) -> AgentRead:
"""Record a heartbeat for a specific agent."""
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -751,12 +813,14 @@ async def heartbeat_agent(
@router.post("/heartbeat", response_model=AgentRead)
async def heartbeat_or_create_agent(
async def heartbeat_or_create_agent( # noqa: C901, PLR0912, PLR0915
payload: AgentHeartbeatCreate,
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP,
) -> AgentRead:
# Agent tokens must heartbeat their authenticated agent record. Names are not unique.
"""Heartbeat an existing agent or create/provision one if needed."""
# Agent tokens must heartbeat their authenticated agent record.
# Names are not unique.
if actor.actor_type == "agent" and actor.agent:
return await heartbeat_agent(
agent_id=str(actor.agent.id),
@@ -793,7 +857,9 @@ async def heartbeat_or_create_agent(
agent.agent_token_hash = hash_agent_token(raw_token)
agent.provision_requested_at = utcnow()
agent.provision_action = "provision"
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
session_key, session_error = await _ensure_gateway_session(
agent.name, client_config,
)
agent.openclaw_session_id = session_key
session.add(agent)
await session.commit()
@@ -814,7 +880,9 @@ async def heartbeat_or_create_agent(
)
await session.commit()
try:
await provision_agent(agent, board, gateway, raw_token, actor.user, action="provision")
await provision_agent(
agent, board, gateway, raw_token, actor.user, action="provision",
)
await _send_wakeup_message(agent, client_config, verb="provisioned")
agent.provision_confirm_token_hash = None
agent.provision_requested_at = None
@@ -864,7 +932,7 @@ async def heartbeat_or_create_agent(
)
gateway, client_config = await _require_gateway(session, board)
await provision_agent(
agent, board, gateway, raw_token, actor.user, action="provision"
agent, board, gateway, raw_token, actor.user, action="provision",
)
await _send_wakeup_message(agent, client_config, verb="provisioned")
agent.provision_confirm_token_hash = None
@@ -903,7 +971,9 @@ async def heartbeat_or_create_agent(
write=actor.actor_type == "user",
)
gateway, client_config = await _require_gateway(session, board)
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
session_key, session_error = await _ensure_gateway_session(
agent.name, client_config,
)
agent.openclaw_session_id = session_key
if session_error:
record_activity(
@@ -937,15 +1007,18 @@ async def heartbeat_or_create_agent(
@router.delete("/{agent_id}", response_model=OkResponse)
async def delete_agent(
agent_id: str,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse:
"""Delete an agent and clean related task state."""
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
return OkResponse()
await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
board = await _require_board(session, str(agent.board_id) if agent.board_id else None)
board = await _require_board(
session, str(agent.board_id) if agent.board_id else None,
)
gateway, client_config = await _require_gateway(session, board)
try:
workspace_path = await cleanup_agent(agent, gateway)
@@ -970,7 +1043,7 @@ async def delete_agent(
message=f"Deleted agent {agent.name}.",
agent_id=None,
)
now = datetime.now()
now = utcnow()
await crud.update_where(
session,
Task,

View File

@@ -1,3 +1,5 @@
"""Authentication bootstrap endpoints for the Mission Control API."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
@@ -6,10 +8,12 @@ from app.core.auth import AuthContext, get_auth_context
from app.schemas.users import UserRead
router = APIRouter(prefix="/auth", tags=["auth"])
AUTH_CONTEXT_DEP = Depends(get_auth_context)
@router.post("/bootstrap", response_model=UserRead)
async def bootstrap_user(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
async def bootstrap_user(auth: AuthContext = AUTH_CONTEXT_DEP) -> UserRead:
"""Return the authenticated user profile from token claims."""
if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return UserRead.model_validate(auth.user)

View File

@@ -1,13 +1,16 @@
"""Board onboarding endpoints for user/agent collaboration."""
# ruff: noqa: E501
from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import ValidationError
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import (
ActorContext,
@@ -18,15 +21,17 @@ from app.api.deps import (
require_admin_or_agent,
)
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext
from app.core.config import settings
from app.core.time import utcnow
from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent
from app.models.board_onboarding import BoardOnboardingSession
from app.models.boards import Board
from app.models.gateways import Gateway
from app.schemas.board_onboarding import (
BoardOnboardingAgentComplete,
@@ -41,12 +46,24 @@ from app.schemas.board_onboarding import (
from app.schemas.boards import BoardRead
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext
from app.models.boards import Board
router = APIRouter(prefix="/boards/{board_id}/onboarding", tags=["board-onboarding"])
logger = logging.getLogger(__name__)
BOARD_USER_READ_DEP = Depends(get_board_for_user_read)
BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write)
BOARD_OR_404_DEP = Depends(get_board_or_404)
SESSION_DEP = Depends(get_session)
ACTOR_DEP = Depends(require_admin_or_agent)
ADMIN_AUTH_DEP = Depends(require_admin_auth)
async def _gateway_config(
session: AsyncSession, board: Board
session: AsyncSession, board: Board,
) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@@ -61,7 +78,7 @@ def _build_session_key(agent_name: str) -> str:
return f"agent:{slug or uuid4().hex}:main"
def _lead_agent_name(board: Board) -> str:
def _lead_agent_name(_board: Board) -> str:
return "Lead Agent"
@@ -69,7 +86,7 @@ def _lead_session_key(board: Board) -> str:
return f"agent:lead-{board.id}:main"
async def _ensure_lead_agent(
async def _ensure_lead_agent( # noqa: PLR0913
session: AsyncSession,
board: Board,
gateway: Gateway,
@@ -100,7 +117,11 @@ async def _ensure_lead_agent(
}
if identity_profile:
merged_identity_profile.update(
{key: value.strip() for key, value in identity_profile.items() if value.strip()}
{
key: value.strip()
for key, value in identity_profile.items()
if value.strip()
},
)
agent = Agent(
@@ -121,7 +142,9 @@ async def _ensure_lead_agent(
await session.refresh(agent)
try:
await provision_agent(agent, board, gateway, raw_token, auth.user, action="provision")
await provision_agent(
agent, board, gateway, raw_token, auth.user, action="provision",
)
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name)
await send_message(
(
@@ -141,9 +164,10 @@ async def _ensure_lead_agent(
@router.get("", response_model=BoardOnboardingRead)
async def get_onboarding(
board: Board = Depends(get_board_for_user_read),
session: AsyncSession = Depends(get_session),
board: Board = BOARD_USER_READ_DEP,
session: AsyncSession = SESSION_DEP,
) -> BoardOnboardingSession:
"""Get the latest onboarding session for a board."""
onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc())
@@ -156,10 +180,11 @@ async def get_onboarding(
@router.post("/start", response_model=BoardOnboardingRead)
async def start_onboarding(
payload: BoardOnboardingStart,
board: Board = Depends(get_board_for_user_write),
session: AsyncSession = Depends(get_session),
_payload: BoardOnboardingStart,
board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = SESSION_DEP,
) -> BoardOnboardingSession:
"""Start onboarding and send instructions to the gateway main agent."""
onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.filter(col(BoardOnboardingSession.status) == "active")
@@ -219,15 +244,21 @@ async def start_onboarding(
try:
await ensure_session(session_key, config=config, label="Main Agent")
await send_message(prompt, session_key=session_key, config=config, deliver=False)
await send_message(
prompt, session_key=session_key, config=config, deliver=False,
)
except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
onboarding = BoardOnboardingSession(
board_id=board.id,
session_key=session_key,
status="active",
messages=[{"role": "user", "content": prompt, "timestamp": utcnow().isoformat()}],
messages=[
{"role": "user", "content": prompt, "timestamp": utcnow().isoformat()},
],
)
session.add(onboarding)
await session.commit()
@@ -238,9 +269,10 @@ async def start_onboarding(
@router.post("/answer", response_model=BoardOnboardingRead)
async def answer_onboarding(
payload: BoardOnboardingAnswer,
board: Board = Depends(get_board_for_user_write),
session: AsyncSession = Depends(get_session),
board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = SESSION_DEP,
) -> BoardOnboardingSession:
"""Send a user onboarding answer to the gateway main agent."""
onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc())
@@ -255,15 +287,22 @@ async def answer_onboarding(
answer_text = f"{payload.answer}: {payload.other_text}"
messages = list(onboarding.messages or [])
messages.append({"role": "user", "content": answer_text, "timestamp": utcnow().isoformat()})
messages.append(
{"role": "user", "content": answer_text, "timestamp": utcnow().isoformat()},
)
try:
await ensure_session(onboarding.session_key, config=config, label="Main Agent")
await send_message(
answer_text, session_key=onboarding.session_key, config=config, deliver=False
answer_text,
session_key=onboarding.session_key,
config=config,
deliver=False,
)
except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
onboarding.messages = messages
onboarding.updated_at = utcnow()
@@ -276,10 +315,11 @@ async def answer_onboarding(
@router.post("/agent", response_model=BoardOnboardingRead)
async def agent_onboarding_update(
payload: BoardOnboardingAgentUpdate,
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
board: Board = BOARD_OR_404_DEP,
session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP,
) -> BoardOnboardingSession:
"""Store onboarding updates submitted by the gateway main agent."""
if actor.actor_type != "agent" or actor.agent is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
agent = actor.agent
@@ -288,9 +328,13 @@ async def agent_onboarding_update(
if board.gateway_id:
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway and gateway.main_session_key and agent.openclaw_session_id:
if agent.openclaw_session_id != gateway.main_session_key:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if (
gateway
and gateway.main_session_key
and agent.openclaw_session_id
and agent.openclaw_session_id != gateway.main_session_key
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
@@ -315,9 +359,13 @@ async def agent_onboarding_update(
if isinstance(payload, BoardOnboardingAgentComplete):
onboarding.draft_goal = payload_data
onboarding.status = "completed"
messages.append({"role": "assistant", "content": payload_text, "timestamp": now})
messages.append(
{"role": "assistant", "content": payload_text, "timestamp": now},
)
else:
messages.append({"role": "assistant", "content": payload_text, "timestamp": now})
messages.append(
{"role": "assistant", "content": payload_text, "timestamp": now},
)
onboarding.messages = messages
onboarding.updated_at = utcnow()
@@ -334,12 +382,13 @@ async def agent_onboarding_update(
@router.post("/confirm", response_model=BoardRead)
async def confirm_onboarding(
async def confirm_onboarding( # noqa: C901, PLR0912, PLR0915
payload: BoardOnboardingConfirm,
board: Board = Depends(get_board_for_user_write),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = SESSION_DEP,
auth: AuthContext = ADMIN_AUTH_DEP,
) -> Board:
"""Confirm onboarding results and provision the board lead agent."""
onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc())
@@ -409,7 +458,9 @@ async def confirm_onboarding(
if lead_agent.update_cadence:
lead_identity_profile["update_cadence"] = lead_agent.update_cadence
if lead_agent.custom_instructions:
lead_identity_profile["custom_instructions"] = lead_agent.custom_instructions
lead_identity_profile["custom_instructions"] = (
lead_agent.custom_instructions
)
gateway, config = await _gateway_config(session, board)
session.add(board)

View File

@@ -1,12 +1,14 @@
"""Board CRUD and snapshot endpoints."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import (
get_board_for_actor_read,
@@ -47,9 +49,23 @@ from app.services.board_group_snapshot import build_board_group_snapshot
from app.services.board_snapshot import build_board_snapshot
from app.services.organizations import OrganizationContext, board_access_filter
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/boards", tags=["boards"])
AGENT_SESSION_PREFIX = "agent"
SESSION_DEP = Depends(get_session)
ORG_ADMIN_DEP = Depends(require_org_admin)
ORG_MEMBER_DEP = Depends(require_org_member)
BOARD_USER_READ_DEP = Depends(get_board_for_user_read)
BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write)
BOARD_ACTOR_READ_DEP = Depends(get_board_for_actor_read)
GATEWAY_ID_QUERY = Query(default=None)
BOARD_GROUP_ID_QUERY = Query(default=None)
INCLUDE_SELF_QUERY = Query(default=False)
INCLUDE_DONE_QUERY = Query(default=False)
PER_BOARD_TASK_LIMIT_QUERY = Query(default=5, ge=0, le=100)
def _slugify(value: str) -> str:
@@ -83,10 +99,12 @@ async def _require_gateway(
async def _require_gateway_for_create(
payload: BoardCreate,
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = ORG_ADMIN_DEP,
session: AsyncSession = SESSION_DEP,
) -> Gateway:
return await _require_gateway(session, payload.gateway_id, organization_id=ctx.organization.id)
return await _require_gateway(
session, payload.gateway_id, organization_id=ctx.organization.id,
)
async def _require_board_group(
@@ -111,8 +129,8 @@ async def _require_board_group(
async def _require_board_group_for_create(
payload: BoardCreate,
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = ORG_ADMIN_DEP,
session: AsyncSession = SESSION_DEP,
) -> BoardGroup | None:
if payload.board_group_id is None:
return None
@@ -123,6 +141,10 @@ async def _require_board_group_for_create(
)
GATEWAY_CREATE_DEP = Depends(_require_gateway_for_create)
BOARD_GROUP_CREATE_DEP = Depends(_require_board_group_for_create)
async def _apply_board_update(
*,
payload: BoardUpdate,
@@ -132,7 +154,7 @@ async def _apply_board_update(
updates = payload.model_dump(exclude_unset=True)
if "gateway_id" in updates:
await _require_gateway(
session, updates["gateway_id"], organization_id=board.organization_id
session, updates["gateway_id"], organization_id=board.organization_id,
)
if "board_group_id" in updates and updates["board_group_id"] is not None:
await _require_board_group(
@@ -141,13 +163,15 @@ async def _apply_board_update(
organization_id=board.organization_id,
)
crud.apply_updates(board, updates)
if updates.get("board_type") == "goal":
if (
updates.get("board_type") == "goal"
and (not board.objective or not board.success_metrics)
):
# Validate only when explicitly switching to goal boards.
if not board.objective or not board.success_metrics:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Goal boards require objective and success_metrics",
)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Goal boards require objective and success_metrics",
)
if not board.gateway_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -158,7 +182,7 @@ async def _apply_board_update(
async def _board_gateway(
session: AsyncSession, board: Board
session: AsyncSession, board: Board,
) -> tuple[Gateway | None, GatewayClientConfig | None]:
if not board.gateway_id:
return None, None
@@ -218,28 +242,32 @@ async def _cleanup_agent_on_gateway(
@router.get("", response_model=DefaultLimitOffsetPage[BoardRead])
async def list_boards(
gateway_id: UUID | None = Query(default=None),
board_group_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_member),
gateway_id: UUID | None = GATEWAY_ID_QUERY,
board_group_id: UUID | None = BOARD_GROUP_ID_QUERY,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[BoardRead]:
"""List boards visible to the current organization member."""
statement = select(Board).where(board_access_filter(ctx.member, write=False))
if gateway_id is not None:
statement = statement.where(col(Board.gateway_id) == gateway_id)
if board_group_id is not None:
statement = statement.where(col(Board.board_group_id) == board_group_id)
statement = statement.order_by(func.lower(col(Board.name)).asc(), col(Board.created_at).desc())
statement = statement.order_by(
func.lower(col(Board.name)).asc(), col(Board.created_at).desc(),
)
return await paginate(session, statement)
@router.post("", response_model=BoardRead)
async def create_board(
payload: BoardCreate,
_gateway: Gateway = Depends(_require_gateway_for_create),
_board_group: BoardGroup | None = Depends(_require_board_group_for_create),
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
_gateway: Gateway = GATEWAY_CREATE_DEP,
_board_group: BoardGroup | None = BOARD_GROUP_CREATE_DEP,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> Board:
"""Create a board in the active organization."""
data = payload.model_dump()
data["organization_id"] = ctx.organization.id
return await crud.create(session, Board, **data)
@@ -247,27 +275,31 @@ async def create_board(
@router.get("/{board_id}", response_model=BoardRead)
def get_board(
board: Board = Depends(get_board_for_user_read),
board: Board = BOARD_USER_READ_DEP,
) -> Board:
"""Get a board by id."""
return board
@router.get("/{board_id}/snapshot", response_model=BoardSnapshot)
async def get_board_snapshot(
board: Board = Depends(get_board_for_actor_read),
session: AsyncSession = Depends(get_session),
board: Board = BOARD_ACTOR_READ_DEP,
session: AsyncSession = SESSION_DEP,
) -> BoardSnapshot:
"""Get a board snapshot view model."""
return await build_board_snapshot(session, board)
@router.get("/{board_id}/group-snapshot", response_model=BoardGroupSnapshot)
async def get_board_group_snapshot(
include_self: bool = Query(default=False),
include_done: bool = Query(default=False),
per_board_task_limit: int = Query(default=5, ge=0, le=100),
board: Board = Depends(get_board_for_actor_read),
session: AsyncSession = Depends(get_session),
*,
include_self: bool = INCLUDE_SELF_QUERY,
include_done: bool = INCLUDE_DONE_QUERY,
per_board_task_limit: int = PER_BOARD_TASK_LIMIT_QUERY,
board: Board = BOARD_ACTOR_READ_DEP,
session: AsyncSession = SESSION_DEP,
) -> BoardGroupSnapshot:
"""Get a grouped snapshot across related boards."""
return await build_board_group_snapshot(
session,
board=board,
@@ -280,19 +312,23 @@ async def get_board_group_snapshot(
@router.patch("/{board_id}", response_model=BoardRead)
async def update_board(
payload: BoardUpdate,
session: AsyncSession = Depends(get_session),
board: Board = Depends(get_board_for_user_write),
session: AsyncSession = SESSION_DEP,
board: Board = BOARD_USER_WRITE_DEP,
) -> Board:
"""Update mutable board properties."""
return await _apply_board_update(payload=payload, session=session, board=board)
@router.delete("/{board_id}", response_model=OkResponse)
async def delete_board(
session: AsyncSession = Depends(get_session),
board: Board = Depends(get_board_for_user_write),
session: AsyncSession = SESSION_DEP,
board: Board = BOARD_USER_WRITE_DEP,
) -> OkResponse:
"""Delete a board and all dependent records."""
agents = await Agent.objects.filter_by(board_id=board.id).all(session)
task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id)))
task_ids = list(
await session.exec(select(Task.id).where(Task.board_id == board.id)),
)
config, client_config = await _board_gateway(session, board)
if config and client_config:
@@ -307,20 +343,31 @@ async def delete_board(
if task_ids:
await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
session,
ActivityEvent,
col(ActivityEvent.task_id).in_(task_ids),
commit=False,
)
await crud.delete_where(session, TaskDependency, col(TaskDependency.board_id) == board.id)
await crud.delete_where(session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id)
await crud.delete_where(
session, TaskDependency, col(TaskDependency.board_id) == board.id,
)
await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id,
)
# Approvals can reference tasks and agents, so delete before both.
await crud.delete_where(session, Approval, col(Approval.board_id) == board.id)
await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id)
await crud.delete_where(
session, BoardOnboardingSession, col(BoardOnboardingSession.board_id) == board.id
session,
BoardOnboardingSession,
col(BoardOnboardingSession.board_id) == board.id,
)
await crud.delete_where(
session, OrganizationBoardAccess, col(OrganizationBoardAccess.board_id) == board.id
session,
OrganizationBoardAccess,
col(OrganizationBoardAccess.board_id) == board.id,
)
await crud.delete_where(
session,
@@ -328,14 +375,17 @@ async def delete_board(
col(OrganizationInviteBoardAccess.board_id) == board.id,
)
# Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so
# Tasks reference agents and have dependent records.
# delete tasks before agents.
await crud.delete_where(session, Task, col(Task.board_id) == board.id)
if agents:
agent_ids = [agent.id for agent in agents]
await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
session,
ActivityEvent,
col(ActivityEvent.agent_id).in_(agent_ids),
commit=False,
)
await crud.delete_where(session, Agent, col(Agent.id).in_(agent_ids))
await session.delete(board)

View File

@@ -1,16 +1,17 @@
"""Organization management endpoints and membership/invite flows."""
from __future__ import annotations
import secrets
from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Sequence
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin, require_org_member
from app.core.auth import AuthContext, get_auth_context
from app.core.auth import get_auth_context
from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate
@@ -63,10 +64,21 @@ from app.services.organizations import (
set_active_organization,
)
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext
router = APIRouter(prefix="/organizations", tags=["organizations"])
SESSION_DEP = Depends(get_session)
AUTH_DEP = Depends(get_auth_context)
ORG_MEMBER_DEP = Depends(require_org_member)
ORG_ADMIN_DEP = Depends(require_org_admin)
def _member_to_read(member: OrganizationMember, user: User | None) -> OrganizationMemberRead:
def _member_to_read(
member: OrganizationMember, user: User | None,
) -> OrganizationMemberRead:
model = OrganizationMemberRead.model_validate(member, from_attributes=True)
if user is not None:
model.user = OrganizationUserRead.model_validate(user, from_attributes=True)
@@ -100,9 +112,10 @@ async def _require_org_invite(
@router.post("", response_model=OrganizationRead)
async def create_organization(
payload: OrganizationCreate,
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = SESSION_DEP,
auth: AuthContext = AUTH_DEP,
) -> OrganizationRead:
"""Create an organization and assign the caller as owner."""
if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
name = payload.name.strip()
@@ -110,7 +123,9 @@ async def create_organization(
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
existing = (
await session.exec(
select(Organization).where(func.lower(col(Organization.name)) == name.lower())
select(Organization).where(
func.lower(col(Organization.name)) == name.lower(),
),
)
).first()
if existing is not None:
@@ -140,19 +155,25 @@ async def create_organization(
@router.get("/me/list", response_model=list[OrganizationListItem])
async def list_my_organizations(
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = SESSION_DEP,
auth: AuthContext = AUTH_DEP,
) -> list[OrganizationListItem]:
"""List organizations where the current user is a member."""
if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await get_active_membership(session, auth.user)
db_user = await User.objects.by_id(auth.user.id).first(session)
active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id
active_id = (
db_user.active_organization_id if db_user else auth.user.active_organization_id
)
statement = (
select(Organization, OrganizationMember)
.join(OrganizationMember, col(OrganizationMember.organization_id) == col(Organization.id))
.join(
OrganizationMember,
col(OrganizationMember.organization_id) == col(Organization.id),
)
.where(col(OrganizationMember.user_id) == auth.user.id)
.order_by(func.lower(col(Organization.name)).asc())
)
@@ -171,30 +192,37 @@ async def list_my_organizations(
@router.patch("/me/active", response_model=OrganizationRead)
async def set_active_org(
payload: OrganizationActiveUpdate,
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = SESSION_DEP,
auth: AuthContext = AUTH_DEP,
) -> OrganizationRead:
"""Set the caller's active organization."""
if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
member = await set_active_organization(
session, user=auth.user, organization_id=payload.organization_id
session, user=auth.user, organization_id=payload.organization_id,
)
organization = await Organization.objects.by_id(member.organization_id).first(
session,
)
organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return OrganizationRead.model_validate(organization, from_attributes=True)
@router.get("/me", response_model=OrganizationRead)
async def get_my_org(ctx: OrganizationContext = Depends(require_org_member)) -> OrganizationRead:
async def get_my_org(
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> OrganizationRead:
"""Return the caller's active organization."""
return OrganizationRead.model_validate(ctx.organization, from_attributes=True)
@router.delete("/me", response_model=OkResponse)
async def delete_my_org(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse:
"""Delete the active organization and related entities."""
if ctx.member.role != "owner":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -206,28 +234,39 @@ async def delete_my_org(
task_ids = select(Task.id).where(col(Task.board_id).in_(board_ids))
agent_ids = select(Agent.id).where(col(Agent.board_id).in_(board_ids))
member_ids = select(OrganizationMember.id).where(
col(OrganizationMember.organization_id) == org_id
col(OrganizationMember.organization_id) == org_id,
)
invite_ids = select(OrganizationInvite.id).where(
col(OrganizationInvite.organization_id) == org_id
col(OrganizationInvite.organization_id) == org_id,
)
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False,
)
await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
session,
ActivityEvent,
col(ActivityEvent.agent_id).in_(agent_ids),
commit=False,
)
await crud.delete_where(
session, TaskDependency, col(TaskDependency.board_id).in_(board_ids), commit=False
session,
TaskDependency,
col(TaskDependency.board_id).in_(board_ids),
commit=False,
)
await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.board_id).in_(board_ids), commit=False
session,
TaskFingerprint,
col(TaskFingerprint.board_id).in_(board_ids),
commit=False,
)
await crud.delete_where(session, Approval, col(Approval.board_id).in_(board_ids), commit=False)
await crud.delete_where(
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False
session, Approval, col(Approval.board_id).in_(board_ids), commit=False,
)
await crud.delete_where(
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False,
)
await crud.delete_where(
session,
@@ -259,9 +298,15 @@ async def delete_my_org(
col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids),
commit=False,
)
await crud.delete_where(session, Task, col(Task.board_id).in_(board_ids), commit=False)
await crud.delete_where(session, Agent, col(Agent.board_id).in_(board_ids), commit=False)
await crud.delete_where(session, Board, col(Board.organization_id) == org_id, commit=False)
await crud.delete_where(
session, Task, col(Task.board_id).in_(board_ids), commit=False,
)
await crud.delete_where(
session, Agent, col(Agent.board_id).in_(board_ids), commit=False,
)
await crud.delete_where(
session, Board, col(Board.organization_id) == org_id, commit=False,
)
await crud.delete_where(
session,
BoardGroupMemory,
@@ -269,9 +314,11 @@ async def delete_my_org(
commit=False,
)
await crud.delete_where(
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False,
)
await crud.delete_where(
session, Gateway, col(Gateway.organization_id) == org_id, commit=False,
)
await crud.delete_where(session, Gateway, col(Gateway.organization_id) == org_id, commit=False)
await crud.delete_where(
session,
OrganizationInvite,
@@ -291,32 +338,39 @@ async def delete_my_org(
active_organization_id=None,
commit=False,
)
await crud.delete_where(session, Organization, col(Organization.id) == org_id, commit=False)
await crud.delete_where(
session, Organization, col(Organization.id) == org_id, commit=False,
)
await session.commit()
return OkResponse()
@router.get("/me/member", response_model=OrganizationMemberRead)
async def get_my_membership(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_member),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> OrganizationMemberRead:
"""Get the caller's membership record in the active organization."""
user = await User.objects.by_id(ctx.member.user_id).first(session)
access_rows = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=ctx.member.id
organization_member_id=ctx.member.id,
).all(session)
model = _member_to_read(ctx.member, user)
model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows
]
return model
@router.get("/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead])
@router.get(
"/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead],
)
async def list_org_members(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_member),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[OrganizationMemberRead]:
"""List members for the active organization."""
statement = (
select(OrganizationMember, User)
.join(User, col(User.id) == col(OrganizationMember.user_id))
@@ -336,9 +390,10 @@ async def list_org_members(
@router.get("/me/members/{member_id}", response_model=OrganizationMemberRead)
async def get_org_member(
member_id: UUID,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_member),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> OrganizationMemberRead:
"""Get a specific organization member by id."""
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
@@ -348,11 +403,12 @@ async def get_org_member(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
user = await User.objects.by_id(member.user_id).first(session)
access_rows = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=member.id
organization_member_id=member.id,
).all(session)
model = _member_to_read(member, user)
model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows
]
return model
@@ -361,9 +417,10 @@ async def get_org_member(
async def update_org_member(
member_id: UUID,
payload: OrganizationMemberUpdate,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationMemberRead:
"""Update a member's role in the organization."""
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
@@ -382,9 +439,10 @@ async def update_org_member(
async def update_member_access(
member_id: UUID,
payload: OrganizationMemberAccessUpdate,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationMemberRead:
"""Update board-level access settings for a member."""
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
@@ -395,7 +453,9 @@ async def update_member_access(
if board_ids:
valid_board_ids = {
board.id
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
for board in await Board.objects.filter_by(
organization_id=ctx.organization.id,
)
.filter(col(Board.id).in_(board_ids))
.all(session)
}
@@ -412,9 +472,10 @@ async def update_member_access(
@router.delete("/me/members/{member_id}", response_model=OkResponse)
async def remove_org_member(
member_id: UUID,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse:
"""Remove a member from the active organization."""
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
@@ -432,7 +493,9 @@ async def remove_org_member(
)
if member.role == "owner":
owners = (
await OrganizationMember.objects.filter_by(organization_id=ctx.organization.id)
await OrganizationMember.objects.filter_by(
organization_id=ctx.organization.id,
)
.filter(col(OrganizationMember.role) == "owner")
.all(session)
)
@@ -463,7 +526,9 @@ async def remove_org_member(
user.active_organization_id = fallback_membership
else:
user.active_organization_id = (
fallback_membership.organization_id if fallback_membership is not None else None
fallback_membership.organization_id
if fallback_membership is not None
else None
)
session.add(user)
@@ -471,11 +536,14 @@ async def remove_org_member(
return OkResponse()
@router.get("/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead])
@router.get(
"/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead],
)
async def list_org_invites(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
"""List pending invites for the active organization."""
statement = (
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(OrganizationInvite.accepted_at).is_(None))
@@ -488,9 +556,10 @@ async def list_org_invites(
@router.post("/me/invites", response_model=OrganizationInviteRead)
async def create_org_invite(
payload: OrganizationInviteCreate,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationInviteRead:
"""Create an organization invite for an email address."""
email = normalize_invited_email(payload.invited_email)
if not email:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@@ -526,13 +595,17 @@ async def create_org_invite(
if board_ids:
valid_board_ids = {
board.id
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
for board in await Board.objects.filter_by(
organization_id=ctx.organization.id,
)
.filter(col(Board.id).in_(board_ids))
.all(session)
}
if valid_board_ids != board_ids:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
await apply_invite_board_access(session, invite=invite, entries=payload.board_access)
await apply_invite_board_access(
session, invite=invite, entries=payload.board_access,
)
await session.commit()
await session.refresh(invite)
return OrganizationInviteRead.model_validate(invite, from_attributes=True)
@@ -541,9 +614,10 @@ async def create_org_invite(
@router.delete("/me/invites/{invite_id}", response_model=OrganizationInviteRead)
async def revoke_org_invite(
invite_id: UUID,
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationInviteRead:
"""Revoke a pending invite from the active organization."""
invite = await _require_org_invite(
session,
organization_id=ctx.organization.id,
@@ -562,9 +636,10 @@ async def revoke_org_invite(
@router.post("/invites/accept", response_model=OrganizationMemberRead)
async def accept_org_invite(
payload: OrganizationInviteAccept,
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = SESSION_DEP,
auth: AuthContext = AUTH_DEP,
) -> OrganizationMemberRead:
"""Accept an invite and return resulting membership."""
if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
invite = await OrganizationInvite.objects.filter(
@@ -573,11 +648,13 @@ async def accept_org_invite(
).first(session)
if invite is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if invite.invited_email and auth.user.email:
if normalize_invited_email(invite.invited_email) != normalize_invited_email(
auth.user.email
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if (
invite.invited_email
and auth.user.email
and normalize_invited_email(invite.invited_email)
!= normalize_invited_email(auth.user.email)
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
existing = await get_member(
session,

View File

@@ -1,41 +1,54 @@
"""API-level thin wrapper around query-set helpers with HTTP conveniences."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from app.db.queryset import QuerySet, qs
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT")
@dataclass(frozen=True)
class APIQuerySet(Generic[ModelT]):
"""Immutable query-set wrapper tailored for API-layer usage."""
queryset: QuerySet[ModelT]
@property
def statement(self) -> SelectOfScalar[ModelT]:
"""Expose the underlying SQL statement for advanced composition."""
return self.queryset.statement
def filter(self, *criteria: Any) -> APIQuerySet[ModelT]:
def filter(self, *criteria: object) -> APIQuerySet[ModelT]:
"""Return a new queryset with additional SQL criteria applied."""
return APIQuerySet(self.queryset.filter(*criteria))
def order_by(self, *ordering: Any) -> APIQuerySet[ModelT]:
def order_by(self, *ordering: object) -> APIQuerySet[ModelT]:
"""Return a new queryset with ordering clauses applied."""
return APIQuerySet(self.queryset.order_by(*ordering))
def limit(self, value: int) -> APIQuerySet[ModelT]:
"""Return a new queryset with a row limit applied."""
return APIQuerySet(self.queryset.limit(value))
def offset(self, value: int) -> APIQuerySet[ModelT]:
"""Return a new queryset with an offset applied."""
return APIQuerySet(self.queryset.offset(value))
async def all(self, session: AsyncSession) -> list[ModelT]:
"""Fetch all rows for the current queryset."""
return await self.queryset.all(session)
async def first(self, session: AsyncSession) -> ModelT | None:
"""Fetch the first row for the current queryset, if present."""
return await self.queryset.first(session)
async def first_or_404(
@@ -44,6 +57,7 @@ class APIQuerySet(Generic[ModelT]):
*,
detail: str | None = None,
) -> ModelT:
"""Fetch the first row or raise HTTP 404 when no row exists."""
obj = await self.first(session)
if obj is not None:
return obj
@@ -53,4 +67,5 @@ class APIQuerySet(Generic[ModelT]):
def api_qs(model: type[ModelT]) -> APIQuerySet[ModelT]:
"""Create an APIQuerySet for a SQLModel class."""
return APIQuerySet(qs(model))

View File

@@ -1,3 +1,5 @@
"""API routes for searching and fetching souls-directory markdown entries."""
from __future__ import annotations
import re
@@ -13,6 +15,7 @@ from app.schemas.souls_directory import (
from app.services import souls_directory
router = APIRouter(prefix="/souls-directory", tags=["souls-directory"])
ADMIN_OR_AGENT_DEP = Depends(require_admin_or_agent)
_SAFE_SEGMENT_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
_SAFE_SLUG_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
@@ -41,8 +44,9 @@ def _validate_segment(value: str, *, field: str) -> str:
async def search(
q: str = Query(default="", min_length=0),
limit: int = Query(default=20, ge=1, le=100),
_actor: ActorContext = Depends(require_admin_or_agent),
_actor: ActorContext = ADMIN_OR_AGENT_DEP,
) -> SoulsDirectorySearchResponse:
"""Search souls-directory entries by handle/slug query text."""
refs = await souls_directory.list_souls_directory_refs()
matches = souls_directory.search_souls(refs, query=q, limit=limit)
items = [
@@ -62,12 +66,23 @@ async def search(
async def get_markdown(
handle: str,
slug: str,
_actor: ActorContext = Depends(require_admin_or_agent),
_actor: ActorContext = ADMIN_OR_AGENT_DEP,
) -> SoulsDirectoryMarkdownResponse:
"""Fetch markdown content for a validated souls-directory handle and slug."""
safe_handle = _validate_segment(handle, field="handle")
safe_slug = _validate_segment(slug.removesuffix(".md"), field="slug")
try:
content = await souls_directory.fetch_soul_markdown(handle=safe_handle, slug=safe_slug)
content = await souls_directory.fetch_soul_markdown(
handle=safe_handle,
slug=safe_slug,
)
except Exception as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
return SoulsDirectoryMarkdownResponse(handle=safe_handle, slug=safe_slug, content=content)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc
return SoulsDirectoryMarkdownResponse(
handle=safe_handle,
slug=safe_slug,
content=content,
)

View File

@@ -1,17 +1,19 @@
"""Task API routes for listing, streaming, and mutating board tasks."""
from __future__ import annotations
import asyncio
import json
from collections import deque
from collections.abc import AsyncIterator, Sequence
from contextlib import suppress
from datetime import datetime, timezone
from typing import cast
from typing import TYPE_CHECKING, cast
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, desc, or_
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select
from sse_starlette.sse import EventSourceResponse
@@ -23,13 +25,16 @@ from app.api.deps import (
require_admin_auth,
require_admin_or_agent,
)
from app.core.auth import AuthContext
from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approvals import Approval
@@ -41,7 +46,13 @@ from app.models.tasks import Task
from app.schemas.common import OkResponse
from app.schemas.errors import BlockedTaskError
from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
from app.schemas.tasks import (
TaskCommentCreate,
TaskCommentRead,
TaskCreate,
TaskRead,
TaskUpdate,
)
from app.services.activity_log import record_activity
from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.organizations import require_board_access
@@ -54,6 +65,11 @@ from app.services.task_dependencies import (
validate_dependency_update,
)
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext
router = APIRouter(prefix="/boards/{board_id}/tasks", tags=["tasks"])
ALLOWED_STATUSES = {"inbox", "in_progress", "review", "done"}
@@ -66,6 +82,14 @@ TASK_EVENT_TYPES = {
SSE_SEEN_MAX = 2000
TASK_SNIPPET_MAX_LEN = 500
TASK_SNIPPET_TRUNCATED_LEN = 497
BOARD_READ_DEP = Depends(get_board_for_actor_read)
ACTOR_DEP = Depends(require_admin_or_agent)
SINCE_QUERY = Query(default=None)
STATUS_QUERY = Query(default=None, alias="status")
BOARD_WRITE_DEP = Depends(get_board_for_user_write)
SESSION_DEP = Depends(get_session)
ADMIN_AUTH_DEP = Depends(require_admin_auth)
TASK_DEP = Depends(get_task_or_404)
def _comment_validation_error() -> HTTPException:
@@ -98,6 +122,7 @@ async def has_valid_recent_comment(
agent_id: UUID | None,
since: datetime | None,
) -> bool:
"""Check whether the task has a recent non-empty comment by the agent."""
if agent_id is None or since is None:
return False
statement = (
@@ -180,8 +205,8 @@ async def _reconcile_dependents_for_dependency_toggle(
await session.exec(
select(Task)
.where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(dependent_ids))
)
.where(col(Task.id).in_(dependent_ids)),
),
)
reopened = previous_status == "done" and dependency_task.status != "done"
@@ -204,7 +229,10 @@ async def _reconcile_dependents_for_dependency_toggle(
session,
event_type="task.status_changed",
task_id=dependent.id,
message=f"Task returned to inbox: dependency reopened ({dependency_task.title}).",
message=(
"Task returned to inbox: dependency reopened "
f"({dependency_task.title})."
),
agent_id=actor_agent_id,
)
else:
@@ -230,7 +258,9 @@ async def _fetch_task_events(
board_id: UUID,
since: datetime,
) -> list[tuple[ActivityEvent, Task | None]]:
task_ids = list(await session.exec(select(Task.id).where(col(Task.board_id) == board_id)))
task_ids = list(
await session.exec(select(Task.id).where(col(Task.board_id) == board_id)),
)
if not task_ids:
return []
statement = cast(
@@ -249,7 +279,9 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
return TaskCommentRead.model_validate(event).model_dump(mode="json")
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
async def _gateway_config(
session: AsyncSession, board: Board,
) -> GatewayClientConfig | None:
if not board.gateway_id:
return None
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
@@ -303,7 +335,10 @@ async def _notify_agent_on_task_assign(
message = (
"TASK ASSIGNED\n"
+ "\n".join(details)
+ "\n\nTake action: open the task and begin work. Post updates as task comments."
+ (
"\n\nTake action: open the task and begin work. "
"Post updates as task comments."
)
)
try:
await _send_agent_task_message(
@@ -442,17 +477,18 @@ async def _notify_lead_on_task_unassigned(
@router.get("/stream")
async def stream_tasks(
async def stream_tasks( # noqa: C901
request: Request,
board: Board = Depends(get_board_for_actor_read),
actor: ActorContext = Depends(require_admin_or_agent),
since: str | None = Query(default=None),
board: Board = BOARD_READ_DEP,
_actor: ActorContext = ACTOR_DEP,
since: str | None = SINCE_QUERY,
) -> EventSourceResponse:
"""Stream task and task-comment events as SSE payloads."""
since_dt = _parse_since(since) or utcnow()
seen_ids: set[UUID] = set()
seen_queue: deque[UUID] = deque()
async def event_generator() -> AsyncIterator[dict[str, str]]:
async def event_generator() -> AsyncIterator[dict[str, str]]: # noqa: C901
last_seen = since_dt
while True:
if await request.is_disconnected():
@@ -510,7 +546,7 @@ async def stream_tasks(
"depends_on_task_ids": dep_list,
"blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by),
}
},
)
.model_dump(mode="json")
)
@@ -521,14 +557,15 @@ async def stream_tasks(
@router.get("", response_model=DefaultLimitOffsetPage[TaskRead])
async def list_tasks(
status_filter: str | None = Query(default=None, alias="status"),
async def list_tasks( # noqa: C901
status_filter: str | None = STATUS_QUERY,
assigned_agent_id: UUID | None = None,
unassigned: bool | None = None,
board: Board = Depends(get_board_for_actor_read),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
board: Board = BOARD_READ_DEP,
session: AsyncSession = SESSION_DEP,
_actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[TaskRead]:
"""List board tasks with optional status and assignment filters."""
statement = select(Task).where(Task.board_id == board.id)
if status_filter:
statuses = [s.strip() for s in status_filter.split(",") if s.strip()]
@@ -550,7 +587,9 @@ async def list_tasks(
if not tasks:
return []
task_ids = [task.id for task in tasks]
deps_map = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids)
deps_map = await dependency_ids_by_task_id(
session, board_id=board.id, task_ids=task_ids,
)
dep_ids: list[UUID] = []
for value in deps_map.values():
dep_ids.extend(value)
@@ -563,7 +602,9 @@ async def list_tasks(
output: list[TaskRead] = []
for task in tasks:
dep_list = deps_map.get(task.id, [])
blocked_by = blocked_by_dependency_ids(dependency_ids=dep_list, status_by_id=dep_status)
blocked_by = blocked_by_dependency_ids(
dependency_ids=dep_list, status_by_id=dep_status,
)
if task.status == "done":
blocked_by = []
output.append(
@@ -572,8 +613,8 @@ async def list_tasks(
"depends_on_task_ids": dep_list,
"blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by),
}
)
},
),
)
return output
@@ -583,10 +624,11 @@ async def list_tasks(
@router.post("", response_model=TaskRead, responses={409: {"model": BlockedTaskError}})
async def create_task(
payload: TaskCreate,
board: Board = Depends(get_board_for_user_write),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
board: Board = BOARD_WRITE_DEP,
session: AsyncSession = SESSION_DEP,
auth: AuthContext = ADMIN_AUTH_DEP,
) -> TaskRead:
"""Create a task and initialize dependency rows."""
data = payload.model_dump()
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
@@ -606,7 +648,9 @@ async def create_task(
board_id=board.id,
dependency_ids=normalized_deps,
)
blocked_by = blocked_by_dependency_ids(dependency_ids=normalized_deps, status_by_id=dep_status)
blocked_by = blocked_by_dependency_ids(
dependency_ids=normalized_deps, status_by_id=dep_status,
)
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
raise _blocked_task_error(blocked_by)
session.add(task)
@@ -618,7 +662,7 @@ async def create_task(
board_id=board.id,
task_id=task.id,
depends_on_task_id=dep_id,
)
),
)
await session.commit()
await session.refresh(task)
@@ -632,7 +676,9 @@ async def create_task(
await session.commit()
await _notify_lead_on_task_create(session=session, board=board, task=task)
if task.assigned_agent_id:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent:
await _notify_agent_on_task_assign(
session=session,
@@ -645,7 +691,7 @@ async def create_task(
"depends_on_task_ids": normalized_deps,
"blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by),
}
},
)
@@ -654,12 +700,13 @@ async def create_task(
response_model=TaskRead,
responses={409: {"model": BlockedTaskError}},
)
async def update_task(
async def update_task( # noqa: C901, PLR0912, PLR0915
payload: TaskUpdate,
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP,
) -> TaskRead:
"""Update task status, assignment, comment, and dependency state."""
if task.board_id is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -676,7 +723,9 @@ async def update_task(
previous_assigned = task.assigned_agent_id
updates = payload.model_dump(exclude_unset=True)
comment = updates.pop("comment", None)
depends_on_task_ids = cast(list[UUID] | None, updates.pop("depends_on_task_ids", None))
depends_on_task_ids = cast(
list[UUID] | None, updates.pop("depends_on_task_ids", None),
)
requested_fields = set(updates)
if comment is not None:
@@ -685,7 +734,9 @@ async def update_task(
requested_fields.add("depends_on_task_ids")
async def _current_dep_ids() -> list[UUID]:
deps_map = await dependency_ids_by_task_id(session, board_id=board_id, task_ids=[task.id])
deps_map = await dependency_ids_by_task_id(
session, board_id=board_id, task_ids=[task.id],
)
return deps_map.get(task.id, [])
async def _blocked_by(dep_ids: Sequence[UUID]) -> list[UUID]:
@@ -696,16 +747,20 @@ async def update_task(
board_id=board_id,
dependency_ids=list(dep_ids),
)
return blocked_by_dependency_ids(dependency_ids=list(dep_ids), status_by_id=dep_status)
return blocked_by_dependency_ids(
dependency_ids=list(dep_ids), status_by_id=dep_status,
)
# Lead agent: delegation only (assign/unassign, resolve review, manage dependencies).
# Lead agent: delegation only.
# Assign/unassign, resolve review, and manage dependencies.
if actor.actor_type == "agent" and actor.agent and actor.agent.is_board_lead:
allowed_fields = {"assigned_agent_id", "status", "depends_on_task_ids"}
if comment is not None or not requested_fields.issubset(allowed_fields):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
"Board leads can only assign/unassign tasks, update dependencies, or resolve review tasks."
"Board leads can only assign/unassign tasks, update "
"dependencies, or resolve review tasks."
),
)
@@ -745,7 +800,11 @@ async def update_task(
status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads cannot assign tasks to themselves.",
)
if agent.board_id and task.board_id and agent.board_id != task.board_id:
if (
agent.board_id
and task.board_id
and agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
task.assigned_agent_id = agent.id
else:
@@ -755,12 +814,18 @@ async def update_task(
if task.status != "review":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads can only change status when a task is in review.",
detail=(
"Board leads can only change status when a task is "
"in review."
),
)
if updates["status"] not in {"done", "inbox"}:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads can only move review tasks to done or inbox.",
detail=(
"Board leads can only move review tasks to done "
"or inbox."
),
)
if updates["status"] == "inbox":
task.assigned_agent_id = None
@@ -793,7 +858,9 @@ async def update_task(
await session.refresh(task)
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent:
board = (
await Board.objects.by_id(task.board_id).first(session)
@@ -817,14 +884,18 @@ async def update_task(
"depends_on_task_ids": dep_ids,
"blocked_by_task_ids": blocked_ids,
"is_blocked": bool(blocked_ids),
}
},
)
# Non-lead agent: can only change status + comment, and cannot start blocked tasks.
if actor.actor_type == "agent":
if actor.agent and actor.agent.board_id and task.board_id:
if actor.agent.board_id != task.board_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if (
actor.agent
and actor.agent.board_id
and task.board_id
and actor.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
allowed_fields = {"status", "comment"}
if depends_on_task_ids is not None or not set(updates).issubset(allowed_fields):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -858,14 +929,16 @@ async def update_task(
)
effective_deps = (
admin_normalized_deps if admin_normalized_deps is not None else await _current_dep_ids()
admin_normalized_deps
if admin_normalized_deps is not None
else await _current_dep_ids()
)
blocked_ids = await _blocked_by(effective_deps)
target_status = cast(str, updates.get("status", task.status))
if blocked_ids and not (task.status == "done" and target_status == "done"):
# Blocked tasks cannot be assigned or moved out of inbox. If the task is already in
# flight, force it back to inbox and unassign it.
# Blocked tasks cannot be assigned or moved out of inbox.
# If the task is already in flight, force it back to inbox and unassign it.
task.status = "inbox"
task.assigned_agent_id = None
task.in_progress_at = None
@@ -910,7 +983,9 @@ async def update_task(
event_type="task.comment",
message=comment,
task_id=task.id,
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None,
agent_id=actor.agent.id
if actor.actor_type == "agent" and actor.agent
else None,
)
session.add(event)
await session.commit()
@@ -921,7 +996,9 @@ async def update_task(
else:
event_type = "task.updated"
message = f"Task updated: {task.title}."
actor_agent_id = actor.agent.id if actor.actor_type == "agent" and actor.agent else None
actor_agent_id = (
actor.agent.id if actor.actor_type == "agent" and actor.agent else None
)
record_activity(
session,
event_type=event_type,
@@ -938,23 +1015,34 @@ async def update_task(
)
await session.commit()
if task.status == "inbox" and task.assigned_agent_id is None:
if previous_status != "inbox" or previous_assigned is not None:
board = (
await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
if (
task.status == "inbox"
and task.assigned_agent_id is None
and (previous_status != "inbox" or previous_assigned is not None)
):
board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
if board:
await _notify_lead_on_task_unassigned(
session=session,
board=board,
task=task,
)
if board:
await _notify_lead_on_task_unassigned(
session=session,
board=board,
task=task,
)
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
if actor.actor_type == "agent" and actor.agent and task.assigned_agent_id == actor.agent.id:
if (
actor.actor_type == "agent"
and actor.agent
and task.assigned_agent_id == actor.agent.id
):
# Don't notify the actor about their own assignment.
pass
else:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent:
board = (
await Board.objects.by_id(task.board_id).first(session)
@@ -978,16 +1066,17 @@ async def update_task(
"depends_on_task_ids": dep_ids,
"blocked_by_task_ids": blocked_ids,
"is_blocked": bool(blocked_ids),
}
},
)
@router.delete("/{task_id}", response_model=OkResponse)
async def delete_task(
session: AsyncSession = Depends(get_session),
task: Task = Depends(get_task_or_404),
auth: AuthContext = Depends(require_admin_auth),
session: AsyncSession = SESSION_DEP,
task: Task = TASK_DEP,
auth: AuthContext = ADMIN_AUTH_DEP,
) -> OkResponse:
"""Delete a task and related records."""
if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
board = await Board.objects.by_id(task.board_id).first(session)
@@ -997,12 +1086,14 @@ async def delete_task(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await require_board_access(session, user=auth.user, board=board, write=True)
await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False,
)
await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False,
)
await crud.delete_where(
session, Approval, col(Approval.task_id) == task.id, commit=False,
)
await crud.delete_where(session, Approval, col(Approval.task_id) == task.id, commit=False)
await crud.delete_where(
session,
TaskDependency,
@@ -1017,11 +1108,14 @@ async def delete_task(
return OkResponse()
@router.get("/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead])
@router.get(
"/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead],
)
async def list_task_comments(
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP,
) -> DefaultLimitOffsetPage[TaskCommentRead]:
"""List comments for a task in chronological order."""
statement = (
select(ActivityEvent)
.where(col(ActivityEvent.task_id) == task.id)
@@ -1032,12 +1126,13 @@ async def list_task_comments(
@router.post("/{task_id}/comments", response_model=TaskCommentRead)
async def create_task_comment(
async def create_task_comment( # noqa: C901, PLR0912
payload: TaskCommentCreate,
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP,
) -> ActivityEvent:
"""Create a task comment and notify relevant agents."""
if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
if actor.actor_type == "user" and actor.user is not None:
@@ -1045,22 +1140,28 @@ async def create_task_comment(
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await require_board_access(session, user=actor.user, board=board, write=True)
if actor.actor_type == "agent" and actor.agent:
if actor.agent.is_board_lead and task.status != "review":
if not await _lead_was_mentioned(session, task, actor.agent) and not _lead_created_task(
task, actor.agent
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
"Board leads can only comment during review, when mentioned, or on tasks they created."
),
)
if (
actor.actor_type == "agent"
and actor.agent
and actor.agent.is_board_lead
and task.status != "review"
and not await _lead_was_mentioned(session, task, actor.agent)
and not _lead_created_task(task, actor.agent)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
"Board leads can only comment during review, when mentioned, "
"or on tasks they created."
),
)
event = ActivityEvent(
event_type="task.comment",
message=payload.message,
task_id=task.id,
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None,
agent_id=actor.agent.id
if actor.actor_type == "agent" and actor.agent
else None,
)
session.add(event)
await session.commit()
@@ -1072,17 +1173,27 @@ async def create_task_comment(
if matches_agent_mention(agent, mention_names):
targets[agent.id] = agent
if not mention_names and task.assigned_agent_id:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent:
targets[assigned_agent.id] = assigned_agent
if actor.actor_type == "agent" and actor.agent:
targets.pop(actor.agent.id, None)
if targets:
board = await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
config = await _gateway_config(session, board) if board else None
if board and config:
snippet = _truncate_snippet(payload.message)
actor_name = actor.agent.name if actor.actor_type == "agent" and actor.agent else "User"
actor_name = (
actor.agent.name
if actor.actor_type == "agent" and actor.agent
else "User"
)
for agent in targets.values():
if not agent.openclaw_session_id:
continue
@@ -1101,15 +1212,14 @@ async def create_task_comment(
f"From: {actor_name}\n\n"
f"{action_line}\n\n"
f"Comment:\n{snippet}\n\n"
"If you are mentioned but not assigned, reply in the task thread but do not change task status."
"If you are mentioned but not assigned, reply in the task "
"thread but do not change task status."
)
try:
with suppress(OpenClawGatewayError):
await _send_agent_task_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=message,
)
except OpenClawGatewayError:
pass
return event

View File

@@ -1,18 +1,28 @@
"""User self-service API endpoints for profile retrieval and updates."""
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext, get_auth_context
from app.db.session import get_session
from app.models.users import User
from app.schemas.users import UserRead, UserUpdate
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.users import User
router = APIRouter(prefix="/users", tags=["users"])
AUTH_CONTEXT_DEP = Depends(get_auth_context)
SESSION_DEP = Depends(get_session)
@router.get("/me", response_model=UserRead)
async def get_me(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
async def get_me(auth: AuthContext = AUTH_CONTEXT_DEP) -> UserRead:
"""Return the authenticated user's current profile payload."""
if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return UserRead.model_validate(auth.user)
@@ -21,9 +31,10 @@ async def get_me(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
@router.patch("/me", response_model=UserRead)
async def update_me(
payload: UserUpdate,
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = SESSION_DEP,
auth: AuthContext = AUTH_CONTEXT_DEP,
) -> UserRead:
"""Apply partial profile updates for the authenticated user."""
if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
updates = payload.model_dump(exclude_unset=True)

View File

@@ -0,0 +1 @@
"""Core utilities and configuration for the backend service."""

View File

@@ -1,33 +1,44 @@
"""Agent authentication helpers for token-backed API access."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import timedelta
from typing import Literal
from typing import TYPE_CHECKING, Literal
from fastapi import Depends, Header, HTTPException, Request, status
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import verify_agent_token
from app.core.time import utcnow
from app.db.session import get_session
from app.models.agents import Agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
logger = logging.getLogger(__name__)
_LAST_SEEN_TOUCH_INTERVAL = timedelta(seconds=30)
_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
SESSION_DEP = Depends(get_session)
@dataclass
class AgentAuthContext:
"""Authenticated actor payload for agent-originated requests."""
actor_type: Literal["agent"]
agent: Agent
async def _find_agent_for_token(session: AsyncSession, token: str) -> Agent | None:
agents = list(await session.exec(select(Agent).where(col(Agent.agent_token_hash).is_not(None))))
agents = list(
await session.exec(
select(Agent).where(col(Agent.agent_token_hash).is_not(None)),
),
)
for agent in agents:
if agent.agent_token_hash and verify_agent_token(token, agent.agent_token_hash):
return agent
@@ -65,9 +76,11 @@ async def _touch_agent_presence(
calls (task comments, memory updates, etc). Touch presence so the UI reflects
real activity even if the heartbeat loop isn't running.
"""
now = utcnow()
if agent.last_seen_at is not None and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL:
if (
agent.last_seen_at is not None
and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL
):
return
agent.last_seen_at = now
@@ -86,9 +99,14 @@ async def get_agent_auth_context(
request: Request,
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
authorization: str | None = Header(default=None, alias="Authorization"),
session: AsyncSession = Depends(get_session),
session: AsyncSession = SESSION_DEP,
) -> AgentAuthContext:
resolved = _resolve_agent_token(agent_token, authorization, accept_authorization=True)
"""Require and validate agent auth token from request headers."""
resolved = _resolve_agent_token(
agent_token,
authorization,
accept_authorization=True,
)
if not resolved:
logger.warning(
"agent auth missing token path=%s x_agent=%s authorization=%s",
@@ -113,8 +131,9 @@ async def get_agent_auth_context_optional(
request: Request,
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
authorization: str | None = Header(default=None, alias="Authorization"),
session: AsyncSession = Depends(get_session),
session: AsyncSession = SESSION_DEP,
) -> AgentAuthContext | None:
"""Optionally resolve agent auth context from `X-Agent-Token` only."""
resolved = _resolve_agent_token(
agent_token,
authorization,

View File

@@ -1,3 +1,5 @@
"""Token generation and verification helpers for agent authentication."""
from __future__ import annotations
import base64
@@ -10,6 +12,7 @@ SALT_BYTES = 16
def generate_agent_token() -> str:
"""Generate a new URL-safe random token for an agent."""
return secrets.token_urlsafe(32)
@@ -23,12 +26,14 @@ def _b64decode(value: str) -> bytes:
def hash_agent_token(token: str) -> str:
"""Hash an agent token using PBKDF2-HMAC-SHA256 with a random salt."""
salt = secrets.token_bytes(SALT_BYTES)
digest = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, ITERATIONS)
return f"pbkdf2_sha256${ITERATIONS}${_b64encode(salt)}${_b64encode(digest)}"
def verify_agent_token(token: str, stored_hash: str) -> bool:
"""Verify a plaintext token against a stored PBKDF2 hash representation."""
try:
algorithm, iterations, salt_b64, digest_b64 = stored_hash.split("$")
except ValueError:
@@ -41,5 +46,10 @@ def verify_agent_token(token: str, stored_hash: str) -> bool:
return False
salt = _b64decode(salt_b64)
expected_digest = _b64decode(digest_b64)
candidate = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, iterations_int)
candidate = hashlib.pbkdf2_hmac(
"sha256",
token.encode("utf-8"),
salt,
iterations_int,
)
return hmac.compare_digest(candidate, expected_digest)

View File

@@ -1,32 +1,42 @@
"""User authentication helpers backed by Clerk JWT verification."""
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal
from typing import TYPE_CHECKING, Literal
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer
from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials
from pydantic import BaseModel, ValidationError
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.db import crud
from app.db.session import get_session
from app.models.users import User
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer(auto_error=False)
SECURITY_DEP = Depends(security)
SESSION_DEP = Depends(get_session)
CLERK_JWKS_URL_REQUIRED_ERROR = "CLERK_JWKS_URL is not set."
class ClerkTokenPayload(BaseModel):
"""JWT claims payload shape required from Clerk tokens."""
sub: str
@lru_cache
def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer:
def _build_clerk_http_bearer(*, auto_error: bool) -> ClerkHTTPBearer:
"""Create and cache the Clerk HTTP bearer guard."""
if not settings.clerk_jwks_url:
raise RuntimeError("CLERK_JWKS_URL is not set.")
raise RuntimeError(CLERK_JWKS_URL_REQUIRED_ERROR)
clerk_config = ClerkConfig(
jwks_url=settings.clerk_jwks_url,
verify_iat=settings.clerk_verify_iat,
@@ -37,12 +47,15 @@ def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer:
@dataclass
class AuthContext:
"""Authenticated user context resolved from inbound auth headers."""
actor_type: Literal["user"]
user: User | None = None
def _resolve_clerk_auth(
request: Request, fallback: ClerkCredentials | None
request: Request,
fallback: ClerkCredentials | None,
) -> ClerkCredentials | None:
auth_data = getattr(request.state, "clerk_auth", None)
if isinstance(auth_data, ClerkCredentials):
@@ -59,9 +72,10 @@ def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
async def get_auth_context(
request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(security),
session: AsyncSession = Depends(get_session),
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = SESSION_DEP,
) -> AuthContext:
"""Resolve required authenticated user context from Clerk JWT headers."""
if credentials is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
@@ -109,9 +123,10 @@ async def get_auth_context(
async def get_auth_context_optional(
request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(security),
session: AsyncSession = Depends(get_session),
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = SESSION_DEP,
) -> AuthContext | None:
"""Resolve user context if available, otherwise return `None`."""
if request.headers.get("X-Agent-Token"):
return None
if credentials is None:

View File

@@ -1,3 +1,5 @@
"""Application settings and environment configuration loading."""
from __future__ import annotations
from pathlib import Path
@@ -11,6 +13,8 @@ DEFAULT_ENV_FILE = BACKEND_ROOT / ".env"
class Settings(BaseSettings):
"""Typed runtime configuration sourced from environment variables."""
model_config = SettingsConfigDict(
# Load `backend/.env` regardless of current working directory.
# (Important when running uvicorn from repo root or via a process manager.)
@@ -32,8 +36,8 @@ class Settings(BaseSettings):
base_url: str = ""
# Optional: local directory where the backend is allowed to write "preserved" agent
# workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local writes are disabled
# and provisioning relies on the gateway API.
# workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local
# writes are disabled and provisioning relies on the gateway API.
#
# Security note: do NOT point this at arbitrary system paths in production.
local_agent_workspace_root: str = ""
@@ -48,8 +52,8 @@ class Settings(BaseSettings):
@model_validator(mode="after")
def _defaults(self) -> Self:
# In dev, default to applying Alembic migrations at startup to avoid schema drift
# (e.g. missing newly-added columns).
# In dev, default to applying Alembic migrations at startup to avoid
# schema drift (e.g. missing newly-added columns).
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":
self.db_auto_migrate = True
return self

View File

@@ -1,8 +1,13 @@
"""Utilities for parsing human-readable duration schedule strings."""
from __future__ import annotations
import re
_DURATION_RE = re.compile(r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$", flags=re.IGNORECASE)
_DURATION_RE = re.compile(
r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$",
flags=re.IGNORECASE,
)
_MULTIPLIERS: dict[str, int] = {
"s": 1,
@@ -11,26 +16,36 @@ _MULTIPLIERS: dict[str, int] = {
"d": 60 * 60 * 24,
"w": 60 * 60 * 24 * 7,
}
_MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10
_ERR_SCHEDULE_REQUIRED = "schedule is required"
_ERR_SCHEDULE_INVALID = (
'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
)
_ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0."
_ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)."
def normalize_every(value: str) -> str:
"""Normalize schedule string to lower-case compact unit form."""
normalized = value.strip().lower().replace(" ", "")
if not normalized:
raise ValueError("schedule is required")
raise ValueError(_ERR_SCHEDULE_REQUIRED)
return normalized
def parse_every_to_seconds(value: str) -> int:
"""Parse compact schedule syntax into a number of seconds."""
normalized = normalize_every(value)
match = _DURATION_RE.match(normalized)
if not match:
raise ValueError('Invalid schedule. Expected format like "10m", "1h", "2d", "1w".')
raise ValueError(_ERR_SCHEDULE_INVALID)
num = int(match.group("num"))
unit = match.group("unit").lower()
seconds = num * _MULTIPLIERS[unit]
if seconds <= 0:
raise ValueError("Schedule must be greater than 0.")
raise ValueError(_ERR_SCHEDULE_NONPOSITIVE)
# Prevent accidental absurd schedules (e.g. 999999999d).
if seconds > 60 * 60 * 24 * 365 * 10:
raise ValueError("Schedule is too large (max 10 years).")
if seconds > _MAX_SCHEDULE_SECONDS:
raise ValueError(_ERR_SCHEDULE_TOO_LARGE)
return seconds

View File

@@ -1,8 +1,10 @@
"""Global exception handlers and request-id middleware for FastAPI."""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any, Final, cast
from typing import TYPE_CHECKING, Any, Final, cast
from uuid import uuid4
from fastapi import FastAPI, Request
@@ -10,7 +12,9 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
if TYPE_CHECKING:
from starlette.types import ASGIApp, Message, Receive, Scope, Send
logger = logging.getLogger(__name__)
@@ -20,12 +24,16 @@ ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]
class RequestIdMiddleware:
"""ASGI middleware that ensures every request has a request-id."""
def __init__(self, app: ASGIApp, *, header_name: str = REQUEST_ID_HEADER) -> None:
"""Initialize middleware with app instance and header name."""
self._app = app
self._header_name = header_name
self._header_name_bytes = header_name.lower().encode("latin-1")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Inject request-id into request state and response headers."""
if scope["type"] != "http":
await self._app(scope, receive, send)
return
@@ -36,8 +44,11 @@ class RequestIdMiddleware:
if message["type"] == "http.response.start":
# Starlette uses `list[tuple[bytes, bytes]]` here.
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
if not any(key.lower() == self._header_name_bytes for key, _ in headers):
headers.append((self._header_name_bytes, request_id.encode("latin-1")))
if not any(
key.lower() == self._header_name_bytes for key, _ in headers
):
request_id_bytes = request_id.encode("latin-1")
headers.append((self._header_name_bytes, request_id_bytes))
await send(message)
await self._app(scope, receive, send_with_request_id)
@@ -62,8 +73,10 @@ class RequestIdMiddleware:
def install_error_handling(app: FastAPI) -> None:
"""Install middleware and exception handlers on the FastAPI app."""
# Important: add request-id middleware last so it's the outermost middleware.
# This ensures it still runs even if another middleware (e.g. CORS preflight) returns early.
# This ensures it still runs even if another middleware
# (e.g. CORS preflight) returns early.
app.add_middleware(RequestIdMiddleware)
app.add_exception_handler(
@@ -88,7 +101,7 @@ def _get_request_id(request: Request) -> str | None:
return None
def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
def _error_payload(*, detail: object, request_id: str | None) -> dict[str, object]:
payload: dict[str, Any] = {"detail": detail}
if request_id:
payload["request_id"] = request_id
@@ -96,7 +109,8 @@ def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
async def _request_validation_handler(
request: Request, exc: RequestValidationError
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
# `RequestValidationError` is expected user input; don't log at ERROR.
request_id = _get_request_id(request)
@@ -107,7 +121,8 @@ async def _request_validation_handler(
async def _response_validation_handler(
request: Request, exc: ResponseValidationError
request: Request,
exc: ResponseValidationError,
) -> JSONResponse:
request_id = _get_request_id(request)
logger.exception(
@@ -125,7 +140,10 @@ async def _response_validation_handler(
)
async def _http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
async def _http_exception_handler(
request: Request,
exc: StarletteHTTPException,
) -> JSONResponse:
request_id = _get_request_id(request)
return JSONResponse(
status_code=exc.status_code,
@@ -134,11 +152,18 @@ async def _http_exception_handler(request: Request, exc: StarletteHTTPException)
)
async def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
async def _unhandled_exception_handler(
request: Request,
_exc: Exception,
) -> JSONResponse:
request_id = _get_request_id(request)
logger.exception(
"unhandled_exception",
extra={"request_id": request_id, "method": request.method, "path": request.url.path},
extra={
"request_id": request_id,
"method": request.method,
"path": request.url.path,
},
)
return JSONResponse(
status_code=500,

View File

@@ -1,3 +1,5 @@
"""Application logging configuration and formatter utilities."""
from __future__ import annotations
import json
@@ -15,7 +17,8 @@ TRACE_LEVEL = 5
logging.addLevelName(TRACE_LEVEL, "TRACE")
def _trace(self: logging.Logger, message: str, *args: Any, **kwargs: Any) -> None:
def _trace(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None:
"""Log a TRACE-level message when the logger is TRACE-enabled."""
if self.isEnabledFor(TRACE_LEVEL):
self._log(TRACE_LEVEL, message, args, **kwargs)
@@ -52,21 +55,31 @@ _STANDARD_LOG_RECORD_ATTRS = {
class AppLogFilter(logging.Filter):
"""Inject app metadata into each log record."""
def __init__(self, app_name: str, version: str) -> None:
"""Initialize the filter with fixed app and version values."""
super().__init__()
self._app_name = app_name
self._version = version
def filter(self, record: logging.LogRecord) -> bool:
"""Attach app metadata fields to each emitted record."""
record.app = self._app_name
record.version = self._version
return True
class JsonFormatter(logging.Formatter):
"""Formatter that serializes log records as compact JSON."""
def format(self, record: logging.LogRecord) -> str:
"""Render a single log record into a JSON string."""
payload: dict[str, Any] = {
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
"timestamp": datetime.fromtimestamp(
record.created,
tz=timezone.utc,
).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
@@ -88,7 +101,10 @@ class JsonFormatter(logging.Formatter):
class KeyValueFormatter(logging.Formatter):
"""Formatter that appends extra fields as `key=value` pairs."""
def format(self, record: logging.LogRecord) -> str:
"""Render a log line with appended non-standard record fields."""
base = super().format(record)
extras = {
key: value
@@ -102,6 +118,8 @@ class KeyValueFormatter(logging.Formatter):
class AppLogger:
"""Centralized logging setup utility for the backend process."""
_configured = False
@classmethod
@@ -111,10 +129,12 @@ class AppLogger:
return level_name, TRACE_LEVEL
if level_name.isdigit():
return level_name, int(level_name)
return level_name, logging._nameToLevel.get(level_name, logging.INFO)
levels = logging.getLevelNamesMapping()
return level_name, levels.get(level_name, logging.INFO)
@classmethod
def configure(cls, *, force: bool = False) -> None:
"""Configure root logging handlers, formatters, and library levels."""
if cls._configured and not force:
return
@@ -127,7 +147,8 @@ class AppLogger:
formatter: logging.Formatter = JsonFormatter()
else:
formatter = KeyValueFormatter(
"%(asctime)s %(levelname)s %(name)s %(message)s app=%(app)s version=%(version)s"
"%(asctime)s %(levelname)s %(name)s %(message)s "
"app=%(app)s version=%(version)s",
)
if settings.log_use_utc:
formatter.converter = time.gmtime
@@ -160,10 +181,12 @@ class AppLogger:
@classmethod
def get_logger(cls, name: str | None = None) -> logging.Logger:
"""Return a logger, ensuring logging has been configured."""
if not cls._configured:
cls.configure()
return logging.getLogger(name)
def configure_logging() -> None:
"""Configure global application logging once during startup."""
AppLogger.configure()

View File

@@ -1,3 +1,5 @@
"""Time-related helpers shared across backend modules."""
from __future__ import annotations
from datetime import UTC, datetime
@@ -5,6 +7,5 @@ from datetime import UTC, datetime
def utcnow() -> datetime:
"""Return a naive UTC datetime without using deprecated datetime.utcnow()."""
# Keep naive UTC values for compatibility with existing DB schema/queries.
return datetime.now(UTC).replace(tzinfo=None)

View File

@@ -1,2 +1,4 @@
"""Application name and version constants."""
APP_NAME = "mission-control"
APP_VERSION = "0.1.0"

View File

@@ -0,0 +1 @@
"""Database helpers and abstractions for backend persistence."""

View File

@@ -1,14 +1,18 @@
"""Typed wrapper around fastapi-pagination for backend query helpers."""
from __future__ import annotations
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar
from app.schemas.pagination import DefaultLimitOffsetPage
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar
T = TypeVar("T")
Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]]
@@ -20,8 +24,10 @@ async def paginate(
*,
transformer: Transformer | None = None,
) -> DefaultLimitOffsetPage[T]:
# fastapi-pagination is not fully typed (it returns Any), but response_model validation
# ensures runtime correctness. Centralize casts here to keep strict mypy clean.
"""Execute a paginated query and cast to the project page type alias."""
# fastapi-pagination is not fully typed (it returns Any), but response_model
# validation ensures runtime correctness. Centralize casts here to keep strict
# mypy clean.
return cast(
DefaultLimitOffsetPage[T],
await _paginate(session, statement, transformer=transformer),

View File

@@ -1,7 +1,9 @@
"""Model manager descriptor utilities for query-set style access."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from typing import Generic, TypeVar
from sqlalchemy import false
from sqlmodel import SQLModel, col
@@ -13,41 +15,55 @@ ModelT = TypeVar("ModelT", bound=SQLModel)
@dataclass(frozen=True)
class ModelManager(Generic[ModelT]):
"""Convenience query manager bound to a SQLModel class."""
model: type[ModelT]
id_field: str = "id"
def all(self) -> QuerySet[ModelT]:
"""Return an unfiltered queryset for the bound model."""
return qs(self.model)
def none(self) -> QuerySet[ModelT]:
"""Return a queryset that yields no rows."""
return qs(self.model).filter(false())
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
def filter(self, *criteria: object) -> QuerySet[ModelT]:
"""Return queryset filtered by SQL criteria expressions."""
return self.all().filter(*criteria)
def where(self, *criteria: Any) -> QuerySet[ModelT]:
def where(self, *criteria: object) -> QuerySet[ModelT]:
"""Alias for `filter`."""
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
"""Return queryset filtered by model field equality values."""
queryset = self.all()
for field_name, value in kwargs.items():
queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
return queryset
def by_id(self, obj_id: Any) -> QuerySet[ModelT]:
def by_id(self, obj_id: object) -> QuerySet[ModelT]:
"""Return queryset filtered by primary identifier field."""
return self.by_field(self.id_field, obj_id)
def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]:
def by_ids(
self,
obj_ids: list[object] | tuple[object, ...] | set[object],
) -> QuerySet[ModelT]:
"""Return queryset filtered by a set/list/tuple of identifiers."""
return self.by_field_in(self.id_field, obj_ids)
def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]:
def by_field(self, field_name: str, value: object) -> QuerySet[ModelT]:
"""Return queryset filtered by a single field equality check."""
return self.filter(col(getattr(self.model, field_name)) == value)
def by_field_in(
self,
field_name: str,
values: list[Any] | tuple[Any, ...] | set[Any],
values: list[object] | tuple[object, ...] | set[object],
) -> QuerySet[ModelT]:
"""Return queryset filtered by `field IN values` semantics."""
seq = tuple(values)
if not seq:
return self.none()
@@ -55,5 +71,8 @@ class ModelManager(Generic[ModelT]):
class ManagerDescriptor(Generic[ModelT]):
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
"""Return a fresh manager bound to the owning model class."""
return ModelManager(owner)

View File

@@ -1,50 +1,67 @@
"""Lightweight immutable query-set wrapper for SQLModel statements."""
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import Any, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT")
@dataclass(frozen=True)
class QuerySet(Generic[ModelT]):
"""Composable immutable wrapper around a SQLModel scalar select statement."""
statement: SelectOfScalar[ModelT]
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
def filter(self, *criteria: object) -> QuerySet[ModelT]:
"""Return a new queryset with additional SQL criteria."""
return replace(self, statement=self.statement.where(*criteria))
def where(self, *criteria: Any) -> QuerySet[ModelT]:
def where(self, *criteria: object) -> QuerySet[ModelT]:
"""Alias for `filter` to mirror SQLAlchemy naming."""
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
"""Return a new queryset filtered by keyword-equality criteria."""
statement = self.statement.filter_by(**kwargs)
return replace(self, statement=statement)
def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
def order_by(self, *ordering: object) -> QuerySet[ModelT]:
"""Return a new queryset with ordering clauses applied."""
return replace(self, statement=self.statement.order_by(*ordering))
def limit(self, value: int) -> QuerySet[ModelT]:
"""Return a new queryset with a SQL row limit."""
return replace(self, statement=self.statement.limit(value))
def offset(self, value: int) -> QuerySet[ModelT]:
"""Return a new queryset with a SQL row offset."""
return replace(self, statement=self.statement.offset(value))
async def all(self, session: AsyncSession) -> list[ModelT]:
"""Execute and return all rows for the current queryset."""
return list(await session.exec(self.statement))
async def first(self, session: AsyncSession) -> ModelT | None:
"""Execute and return the first row, if available."""
return (await session.exec(self.statement)).first()
async def one_or_none(self, session: AsyncSession) -> ModelT | None:
"""Execute and return one row or `None`."""
return (await session.exec(self.statement)).one_or_none()
async def exists(self, session: AsyncSession) -> bool:
"""Return whether the queryset yields at least one row."""
return await self.limit(1).first(session) is not None
def qs(model: type[ModelT]) -> QuerySet[ModelT]:
"""Create a base queryset for a SQLModel class."""
return QuerySet(select(model))

View File

@@ -1,8 +1,10 @@
"""Database engine, session factory, and startup migration helpers."""
from __future__ import annotations
import logging
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import TYPE_CHECKING
import anyio
from alembic import command
@@ -15,6 +17,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app import models as _models
from app.core.config import settings
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
# Import model modules so SQLModel metadata is fully registered at startup.
_MODEL_REGISTRY = _models
@@ -48,12 +53,14 @@ def _alembic_config() -> Config:
def run_migrations() -> None:
"""Apply Alembic migrations to the latest revision."""
logger.info("Running database migrations.")
command.upgrade(_alembic_config(), "head")
logger.info("Database migrations complete.")
async def init_db() -> None:
"""Initialize database schema, running migrations when configured."""
if settings.db_auto_migrate:
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
if any(versions_dir.glob("*.py")):
@@ -67,6 +74,7 @@ async def init_db() -> None:
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""Yield a request-scoped async DB session with safe rollback on errors."""
async with async_session_maker() as session:
try:
yield session

View File

@@ -0,0 +1 @@
"""External system integration clients and protocol adapters."""

View File

@@ -1,3 +1,5 @@
"""OpenClaw gateway protocol constants shared across integration layers."""
from __future__ import annotations
PROTOCOL_VERSION = 3
@@ -116,4 +118,5 @@ GATEWAY_EVENTS_SET = frozenset(GATEWAY_EVENTS)
def is_known_gateway_method(method: str) -> bool:
"""Return whether a method name is part of the known base gateway methods."""
return method in GATEWAY_METHODS_SET

View File

@@ -1,7 +1,9 @@
"""FastAPI application entrypoint and router wiring for the backend."""
from __future__ import annotations
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
@@ -29,11 +31,15 @@ from app.core.error_handling import install_error_handling
from app.core.logging import configure_logging
from app.db.session import init_db
if TYPE_CHECKING:
from collections.abc import AsyncIterator
configure_logging()
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
"""Initialize application resources before serving requests."""
await init_db()
yield
@@ -55,16 +61,19 @@ install_error_handling(app)
@app.get("/health")
def health() -> dict[str, bool]:
"""Lightweight liveness probe endpoint."""
return {"ok": True}
@app.get("/healthz")
def healthz() -> dict[str, bool]:
"""Alias liveness probe endpoint for platform compatibility."""
return {"ok": True}
@app.get("/readyz")
def readyz() -> dict[str, bool]:
"""Readiness probe endpoint for service orchestration checks."""
return {"ok": True}

View File

@@ -1,3 +1,5 @@
"""Model exports for SQLAlchemy/SQLModel metadata discovery."""
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approvals import Approval

View File

@@ -1,6 +1,8 @@
"""Activity event model persisted for audit and feed use-cases."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
class ActivityEvent(QueryModel, table=True):
"""Discrete activity event tied to tasks and agents."""
__tablename__ = "activity_events"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Agent model representing autonomous actors assigned to boards."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from typing import Any
from uuid import UUID, uuid4
@@ -12,6 +14,8 @@ from app.models.base import QueryModel
class Agent(QueryModel, table=True):
"""Agent configuration and lifecycle state persisted in the database."""
__tablename__ = "agents"
id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -20,8 +24,14 @@ class Agent(QueryModel, table=True):
status: str = Field(default="provisioning", index=True)
openclaw_session_id: str | None = Field(default=None, index=True)
agent_token_hash: str | None = Field(default=None, index=True)
heartbeat_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
identity_profile: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
heartbeat_config: dict[str, Any] | None = Field(
default=None,
sa_column=Column(JSON),
)
identity_profile: dict[str, Any] | None = Field(
default=None,
sa_column=Column(JSON),
)
identity_template: str | None = Field(default=None, sa_column=Column(Text))
soul_template: str | None = Field(default=None, sa_column=Column(Text))
provision_requested_at: datetime | None = Field(default=None)

View File

@@ -1,6 +1,8 @@
"""Approval model storing pending and resolved approval actions."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class Approval(QueryModel, table=True):
"""Approval request and decision metadata for gated operations."""
__tablename__ = "approvals"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,3 +1,5 @@
"""Base model mixins and shared SQLModel abstractions."""
from __future__ import annotations
from typing import ClassVar, Self
@@ -8,4 +10,6 @@ from app.db.query_manager import ManagerDescriptor
class QueryModel(SQLModel, table=False):
"""Base SQLModel with a shared query manager descriptor."""
objects: ClassVar[ManagerDescriptor[Self]] = ManagerDescriptor()

View File

@@ -1,6 +1,8 @@
"""Board-group scoped memory entries for shared context."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class BoardGroupMemory(QueryModel, table=True):
"""Persisted memory items associated with a board group."""
__tablename__ = "board_group_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board group model used to organize boards inside organizations."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.tenancy import TenantScoped
class BoardGroup(TenantScoped, table=True):
"""Logical grouping container for boards within an organization."""
__tablename__ = "board_groups"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board-level memory entries for persistent contextual state."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class BoardMemory(QueryModel, table=True):
"""Persisted memory item attached directly to a board."""
__tablename__ = "board_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board onboarding session model for guided setup state."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
@@ -11,13 +13,18 @@ from app.models.base import QueryModel
class BoardOnboardingSession(QueryModel, table=True):
"""Persisted onboarding conversation and draft goal data for a board."""
__tablename__ = "board_onboarding_sessions"
id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID = Field(foreign_key="boards.id", index=True)
session_key: str
status: str = Field(default="active", index=True)
messages: list[dict[str, object]] | None = Field(default=None, sa_column=Column(JSON))
messages: list[dict[str, object]] | None = Field(
default=None,
sa_column=Column(JSON),
)
draft_goal: dict[str, object] | None = Field(default=None, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)

View File

@@ -1,6 +1,8 @@
"""Board model for organization workspaces and goal configuration."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.tenancy import TenantScoped
class Board(TenantScoped, table=True):
"""Primary board entity grouping tasks, agents, and goal metadata."""
__tablename__ = "boards"
id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -18,10 +22,17 @@ class Board(TenantScoped, table=True):
name: str
slug: str = Field(index=True)
gateway_id: UUID | None = Field(default=None, foreign_key="gateways.id", index=True)
board_group_id: UUID | None = Field(default=None, foreign_key="board_groups.id", index=True)
board_group_id: UUID | None = Field(
default=None,
foreign_key="board_groups.id",
index=True,
)
board_type: str = Field(default="goal", index=True)
objective: str | None = None
success_metrics: dict[str, object] | None = Field(default=None, sa_column=Column(JSON))
success_metrics: dict[str, object] | None = Field(
default=None,
sa_column=Column(JSON),
)
target_date: datetime | None = None
goal_confirmed: bool = Field(default=False)
goal_source: str | None = None

View File

@@ -1,6 +1,8 @@
"""Gateway model storing organization-level gateway integration metadata."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
class Gateway(QueryModel, table=True):
"""Configured external gateway endpoint and authentication settings."""
__tablename__ = "gateways"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board-level access grants assigned to organization members."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationBoardAccess(QueryModel, table=True):
"""Member-specific board permissions within an organization."""
__tablename__ = "organization_board_access"
__table_args__ = (
UniqueConstraint(
@@ -21,7 +25,10 @@ class OrganizationBoardAccess(QueryModel, table=True):
)
id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_member_id: UUID = Field(foreign_key="organization_members.id", index=True)
organization_member_id: UUID = Field(
foreign_key="organization_members.id",
index=True,
)
board_id: UUID = Field(foreign_key="boards.id", index=True)
can_read: bool = Field(default=True)
can_write: bool = Field(default=False)

View File

@@ -1,6 +1,8 @@
"""Board access grants attached to pending organization invites."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationInviteBoardAccess(QueryModel, table=True):
"""Invite-specific board permissions applied after invite acceptance."""
__tablename__ = "organization_invite_board_access"
__table_args__ = (
UniqueConstraint(
@@ -21,7 +25,10 @@ class OrganizationInviteBoardAccess(QueryModel, table=True):
)
id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_invite_id: UUID = Field(foreign_key="organization_invites.id", index=True)
organization_invite_id: UUID = Field(
foreign_key="organization_invites.id",
index=True,
)
board_id: UUID = Field(foreign_key="boards.id", index=True)
can_read: bool = Field(default=True)
can_write: bool = Field(default=False)

View File

@@ -1,6 +1,8 @@
"""Organization invite model for email-based tenant membership flow."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationInvite(QueryModel, table=True):
"""Invitation record granting prospective organization access."""
__tablename__ = "organization_invites"
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)
@@ -21,8 +25,16 @@ class OrganizationInvite(QueryModel, table=True):
role: str = Field(default="member", index=True)
all_boards_read: bool = Field(default=False)
all_boards_write: bool = Field(default=False)
created_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
accepted_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
created_by_user_id: UUID | None = Field(
default=None,
foreign_key="users.id",
index=True,
)
accepted_by_user_id: UUID | None = Field(
default=None,
foreign_key="users.id",
index=True,
)
accepted_at: datetime | None = None
created_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)

View File

@@ -1,6 +1,8 @@
"""Organization membership model with role and board-access flags."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationMember(QueryModel, table=True):
"""Membership row linking a user to an organization and permissions."""
__tablename__ = "organization_members"
__table_args__ = (
UniqueConstraint(

View File

@@ -1,6 +1,8 @@
"""Organization model representing top-level tenant entities."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class Organization(QueryModel, table=True):
"""Top-level organization tenant record."""
__tablename__ = "organizations"
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)

View File

@@ -1,6 +1,8 @@
"""Task dependency edge model for board-local dependency graphs."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlalchemy import CheckConstraint, UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.tenancy import TenantScoped
class TaskDependency(TenantScoped, table=True):
"""Directed dependency edge between two tasks in the same board."""
__tablename__ = "task_dependencies"
__table_args__ = (
UniqueConstraint(

View File

@@ -1,6 +1,8 @@
"""Task fingerprint model for duplicate/task-linking operations."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
class TaskFingerprint(QueryModel, table=True):
"""Hashed task-content fingerprint associated with a board and task."""
__tablename__ = "task_fingerprints"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Task model representing board work items and execution metadata."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4
from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.tenancy import TenantScoped
class Task(TenantScoped, table=True):
"""Board-scoped task entity with ownership, status, and timing fields."""
__tablename__ = "tasks"
id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -22,8 +26,16 @@ class Task(TenantScoped, table=True):
due_at: datetime | None = None
in_progress_at: datetime | None = None
created_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
assigned_agent_id: UUID | None = Field(default=None, foreign_key="agents.id", index=True)
created_by_user_id: UUID | None = Field(
default=None,
foreign_key="users.id",
index=True,
)
assigned_agent_id: UUID | None = Field(
default=None,
foreign_key="agents.id",
index=True,
)
auto_created: bool = Field(default=False)
auto_reason: str | None = None

View File

@@ -1,7 +1,9 @@
"""Shared tenancy-scoped model base classes."""
from __future__ import annotations
from app.models.base import QueryModel
class TenantScoped(QueryModel, table=False):
pass
"""Base class for models constrained to a tenant/organization scope."""

View File

@@ -1,3 +1,5 @@
"""User model storing identity and profile preferences."""
from __future__ import annotations
from uuid import UUID, uuid4
@@ -8,6 +10,8 @@ from app.models.base import QueryModel
class User(QueryModel, table=True):
"""Application user account and profile attributes."""
__tablename__ = "users"
id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -21,5 +25,7 @@ class User(QueryModel, table=True):
context: str | None = None
is_super_admin: bool = Field(default=False)
active_organization_id: UUID | None = Field(
default=None, foreign_key="organizations.id", index=True
default=None,
foreign_key="organizations.id",
index=True,
)

View File

@@ -1,3 +1,5 @@
"""Public schema exports shared across API route modules."""
from app.schemas.activity_events import ActivityEventRead
from app.schemas.agents import AgentCreate, AgentRead, AgentUpdate
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalUpdate

View File

@@ -1,12 +1,16 @@
"""Response schemas for activity events and task-comment feed items."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from datetime import datetime # noqa: TCH003
from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel
class ActivityEventRead(SQLModel):
"""Serialized activity event payload returned by activity endpoints."""
id: UUID
event_type: str
message: str | None
@@ -16,6 +20,8 @@ class ActivityEventRead(SQLModel):
class ActivityTaskCommentFeedItemRead(SQLModel):
"""Denormalized task-comment feed item enriched with task and board fields."""
id: UUID
created_at: datetime
message: str | None

View File

@@ -1,16 +1,21 @@
"""Schemas for approval create/update/read API payloads."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from typing import Literal, Self
from uuid import UUID
from uuid import UUID # noqa: TCH003
from pydantic import model_validator
from sqlmodel import SQLModel
ApprovalStatus = Literal["pending", "approved", "rejected"]
STATUS_REQUIRED_ERROR = "status is required"
class ApprovalBase(SQLModel):
"""Shared approval fields used across create/read payloads."""
action_type: str
task_id: UUID | None = None
payload: dict[str, object] | None = None
@@ -20,20 +25,27 @@ class ApprovalBase(SQLModel):
class ApprovalCreate(ApprovalBase):
"""Payload for creating a new approval request."""
agent_id: UUID | None = None
class ApprovalUpdate(SQLModel):
"""Payload for mutating approval status."""
status: ApprovalStatus | None = None
@model_validator(mode="after")
def validate_status(self) -> Self:
"""Ensure explicitly provided `status` is not null."""
if "status" in self.model_fields_set and self.status is None:
raise ValueError("status is required")
raise ValueError(STATUS_REQUIRED_ERROR)
return self
class ApprovalRead(ApprovalBase):
"""Approval payload returned from read endpoints."""
id: UUID
board_id: UUID
agent_id: UUID | None = None

View File

@@ -1,13 +1,18 @@
"""Schemas for applying heartbeat settings to board-group agents."""
from __future__ import annotations
from typing import Any
from uuid import UUID
from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel
class BoardGroupHeartbeatApply(SQLModel):
# Heartbeat cadence string understood by the OpenClaw gateway (e.g. "2m", "10m", "30m").
"""Request payload for heartbeat policy updates."""
# Heartbeat cadence string understood by the OpenClaw gateway
# (e.g. "2m", "10m", "30m").
every: str
# Optional heartbeat target (most deployments use "none").
target: str | None = None
@@ -15,6 +20,8 @@ class BoardGroupHeartbeatApply(SQLModel):
class BoardGroupHeartbeatApplyResult(SQLModel):
"""Result payload describing agents updated by a heartbeat request."""
board_group_id: UUID
requested: dict[str, Any]
updated_agent_ids: list[UUID]

View File

@@ -1,14 +1,18 @@
"""Schemas for board-group memory create/read API payloads."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from datetime import datetime # noqa: TCH003
from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr
from app.schemas.common import NonEmptyStr # noqa: TCH001
class BoardGroupMemoryCreate(SQLModel):
"""Payload for creating a board-group memory entry."""
# For writes, reject blank/whitespace-only content.
content: NonEmptyStr
tags: list[str] | None = None
@@ -16,9 +20,12 @@ class BoardGroupMemoryCreate(SQLModel):
class BoardGroupMemoryRead(SQLModel):
"""Serialized board-group memory entry returned from read endpoints."""
id: UUID
board_group_id: UUID
# For reads, allow legacy rows that may have empty content (avoid response validation 500s).
# For reads, allow legacy rows that may have empty content
# (avoid response validation 500s).
content: str
tags: list[str] | None = None
source: str | None = None

View File

@@ -1,28 +1,36 @@
"""Schemas for board-group create/update/read API operations."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from datetime import datetime # noqa: TCH003
from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel
class BoardGroupBase(SQLModel):
"""Shared board-group fields for create/read operations."""
name: str
slug: str
description: str | None = None
class BoardGroupCreate(BoardGroupBase):
pass
"""Payload for creating a board group."""
class BoardGroupUpdate(SQLModel):
"""Payload for partial board-group updates."""
name: str | None = None
slug: str | None = None
description: str | None = None
class BoardGroupRead(BoardGroupBase):
"""Board-group payload returned from read endpoints."""
id: UUID
organization_id: UUID
created_at: datetime

View File

@@ -1,14 +1,18 @@
"""Schemas for board memory create/read API payloads."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from datetime import datetime # noqa: TCH003
from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr
from app.schemas.common import NonEmptyStr # noqa: TCH001
class BoardMemoryCreate(SQLModel):
"""Payload for creating a board memory entry."""
# For writes, reject blank/whitespace-only content.
content: NonEmptyStr
tags: list[str] | None = None
@@ -16,9 +20,12 @@ class BoardMemoryCreate(SQLModel):
class BoardMemoryRead(SQLModel):
"""Serialized board memory entry returned from read endpoints."""
id: UUID
board_id: UUID
# For reads, allow legacy rows that may have empty content (avoid response validation 500s).
# For reads, allow legacy rows that may have empty content
# (avoid response validation 500s).
content: str
tags: list[str] | None = None
source: str | None = None

View File

@@ -1,14 +1,23 @@
"""Schemas for board create/update/read API operations."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from typing import Self
from uuid import UUID
from uuid import UUID # noqa: TCH003
from pydantic import model_validator
from sqlmodel import SQLModel
_ERR_GOAL_FIELDS_REQUIRED = (
"Confirmed goal boards require objective and success_metrics"
)
_ERR_GATEWAY_REQUIRED = "gateway_id is required"
class BoardBase(SQLModel):
"""Shared board fields used across create and read payloads."""
name: str
slug: str
gateway_id: UUID | None = None
@@ -22,17 +31,25 @@ class BoardBase(SQLModel):
class BoardCreate(BoardBase):
"""Payload for creating a board."""
gateway_id: UUID
@model_validator(mode="after")
def validate_goal_fields(self) -> Self:
if self.board_type == "goal" and self.goal_confirmed:
if not self.objective or not self.success_metrics:
raise ValueError("Confirmed goal boards require objective and success_metrics")
"""Require goal details when creating a confirmed goal board."""
if (
self.board_type == "goal"
and self.goal_confirmed
and (not self.objective or not self.success_metrics)
):
raise ValueError(_ERR_GOAL_FIELDS_REQUIRED)
return self
class BoardUpdate(SQLModel):
"""Payload for partial board updates."""
name: str | None = None
slug: str | None = None
gateway_id: UUID | None = None
@@ -46,13 +63,16 @@ class BoardUpdate(SQLModel):
@model_validator(mode="after")
def validate_gateway_id(self) -> Self:
"""Reject explicit null gateway IDs in patch payloads."""
# Treat explicit null like "unset" is invalid for patch updates.
if "gateway_id" in self.model_fields_set and self.gateway_id is None:
raise ValueError("gateway_id is required")
raise ValueError(_ERR_GATEWAY_REQUIRED)
return self
class BoardRead(BoardBase):
"""Board payload returned from read endpoints."""
id: UUID
organization_id: UUID
created_at: datetime

View File

@@ -1,3 +1,5 @@
"""Common reusable schema primitives and simple API response envelopes."""
from __future__ import annotations
from typing import Annotated
@@ -5,9 +7,12 @@ from typing import Annotated
from pydantic import StringConstraints
from sqlmodel import SQLModel
# Reusable string type for request payloads where blank/whitespace-only values are invalid.
# Reusable string type for request payloads where blank/whitespace-only values
# are invalid.
NonEmptyStr = Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]
class OkResponse(SQLModel):
"""Standard success response payload."""
ok: bool = True

View File

@@ -1,12 +1,18 @@
"""Structured error payload schemas used by API responses."""
from __future__ import annotations
from sqlmodel import Field, SQLModel
class BlockedTaskDetail(SQLModel):
"""Error detail payload listing blocking dependency task identifiers."""
message: str
blocked_by_task_ids: list[str] = Field(default_factory=list)
class BlockedTaskError(SQLModel):
"""Top-level blocked-task error response envelope."""
detail: BlockedTaskDetail

View File

@@ -1,15 +1,21 @@
"""Schemas for gateway passthrough API request and response payloads."""
from __future__ import annotations
from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr
from app.schemas.common import NonEmptyStr # noqa: TCH001
class GatewaySessionMessageRequest(SQLModel):
"""Request payload for sending a message into a gateway session."""
content: NonEmptyStr
class GatewayResolveQuery(SQLModel):
"""Query parameters used to resolve which gateway to target."""
board_id: str | None = None
gateway_url: str | None = None
gateway_token: str | None = None
@@ -17,6 +23,8 @@ class GatewayResolveQuery(SQLModel):
class GatewaysStatusResponse(SQLModel):
"""Aggregated gateway status response including session metadata."""
connected: bool
gateway_url: str
sessions_count: int | None = None
@@ -28,20 +36,28 @@ class GatewaysStatusResponse(SQLModel):
class GatewaySessionsResponse(SQLModel):
"""Gateway sessions list response payload."""
sessions: list[object]
main_session_key: str | None = None
main_session: object | None = None
class GatewaySessionResponse(SQLModel):
"""Single gateway session response payload."""
session: object
class GatewaySessionHistoryResponse(SQLModel):
"""Gateway session history response payload."""
history: list[object]
class GatewayCommandsResponse(SQLModel):
"""Gateway command catalog and protocol metadata."""
protocol_version: int
methods: list[str]
events: list[str]

View File

@@ -1,24 +1,38 @@
"""Schemas for gateway-main and lead-agent coordination endpoints."""
from __future__ import annotations
from typing import Literal
from uuid import UUID
from uuid import UUID # noqa: TCH003
from sqlmodel import Field, SQLModel
from app.schemas.common import NonEmptyStr
from app.schemas.common import NonEmptyStr # noqa: TCH001
def _lead_reply_tags() -> list[str]:
return ["gateway_main", "lead_reply"]
def _user_reply_tags() -> list[str]:
return ["gateway_main", "user_reply"]
class GatewayLeadMessageRequest(SQLModel):
"""Request payload for sending a message to a board lead agent."""
kind: Literal["question", "handoff"] = "question"
correlation_id: str | None = None
content: NonEmptyStr
# How the lead should reply (defaults are interpreted by templates).
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "lead_reply"])
reply_tags: list[str] = Field(default_factory=_lead_reply_tags)
reply_source: str | None = "lead_to_gateway_main"
class GatewayLeadMessageResponse(SQLModel):
"""Response payload for a lead-message dispatch attempt."""
ok: bool = True
board_id: UUID
lead_agent_id: UUID | None = None
@@ -27,15 +41,19 @@ class GatewayLeadMessageResponse(SQLModel):
class GatewayLeadBroadcastRequest(SQLModel):
"""Request payload for broadcasting a message to multiple board leads."""
kind: Literal["question", "handoff"] = "question"
correlation_id: str | None = None
content: NonEmptyStr
board_ids: list[UUID] | None = None
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "lead_reply"])
reply_tags: list[str] = Field(default_factory=_lead_reply_tags)
reply_source: str | None = "lead_to_gateway_main"
class GatewayLeadBroadcastBoardResult(SQLModel):
"""Per-board result entry for a lead broadcast operation."""
board_id: UUID
lead_agent_id: UUID | None = None
lead_agent_name: str | None = None
@@ -44,6 +62,8 @@ class GatewayLeadBroadcastBoardResult(SQLModel):
class GatewayLeadBroadcastResponse(SQLModel):
"""Aggregate response for a lead broadcast operation."""
ok: bool = True
sent: int = 0
failed: int = 0
@@ -51,16 +71,21 @@ class GatewayLeadBroadcastResponse(SQLModel):
class GatewayMainAskUserRequest(SQLModel):
"""Request payload for asking the end user via a main gateway agent."""
correlation_id: str | None = None
content: NonEmptyStr
preferred_channel: str | None = None
# How the main agent should reply back into Mission Control (defaults interpreted by templates).
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "user_reply"])
# How the main agent should reply back into Mission Control
# (defaults interpreted by templates).
reply_tags: list[str] = Field(default_factory=_user_reply_tags)
reply_source: str | None = "user_via_gateway_main"
class GatewayMainAskUserResponse(SQLModel):
"""Response payload for user-question dispatch via gateway main agent."""
ok: bool = True
board_id: UUID
main_agent_id: UUID | None = None

View File

@@ -1,14 +1,17 @@
"""Schemas for gateway CRUD and template-sync API payloads."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from datetime import datetime # noqa: TCH003
from uuid import UUID # noqa: TCH003
from pydantic import field_validator
from sqlmodel import Field, SQLModel
class GatewayBase(SQLModel):
"""Shared gateway fields used across create/read payloads."""
name: str
url: str
main_session_key: str
@@ -16,11 +19,14 @@ class GatewayBase(SQLModel):
class GatewayCreate(GatewayBase):
"""Payload for creating a gateway configuration."""
token: str | None = None
@field_validator("token", mode="before")
@classmethod
def normalize_token(cls, value: Any) -> Any:
def normalize_token(cls, value: object) -> str | None | object:
"""Normalize empty/whitespace tokens to `None`."""
if value is None:
return None
if isinstance(value, str):
@@ -30,6 +36,8 @@ class GatewayCreate(GatewayBase):
class GatewayUpdate(SQLModel):
"""Payload for partial gateway updates."""
name: str | None = None
url: str | None = None
token: str | None = None
@@ -38,7 +46,8 @@ class GatewayUpdate(SQLModel):
@field_validator("token", mode="before")
@classmethod
def normalize_token(cls, value: Any) -> Any:
def normalize_token(cls, value: object) -> str | None | object:
"""Normalize empty/whitespace tokens to `None`."""
if value is None:
return None
if isinstance(value, str):
@@ -48,6 +57,8 @@ class GatewayUpdate(SQLModel):
class GatewayRead(GatewayBase):
"""Gateway payload returned from read endpoints."""
id: UUID
organization_id: UUID
token: str | None = None
@@ -56,6 +67,8 @@ class GatewayRead(GatewayBase):
class GatewayTemplatesSyncError(SQLModel):
"""Per-agent error entry from a gateway template sync operation."""
agent_id: UUID | None = None
agent_name: str | None = None
board_id: UUID | None = None
@@ -63,6 +76,8 @@ class GatewayTemplatesSyncError(SQLModel):
class GatewayTemplatesSyncResult(SQLModel):
"""Summary payload returned by gateway template sync endpoints."""
gateway_id: UUID
include_main: bool
reset_sessions: bool

View File

@@ -1,17 +1,23 @@
"""Dashboard metrics schemas for KPI and time-series API responses."""
from __future__ import annotations
from datetime import datetime
from datetime import datetime # noqa: TCH003
from typing import Literal
from sqlmodel import SQLModel
class DashboardSeriesPoint(SQLModel):
"""Single numeric time-series point."""
period: datetime
value: float
class DashboardWipPoint(SQLModel):
"""Work-in-progress point split by task status buckets."""
period: datetime
inbox: int
in_progress: int
@@ -19,28 +25,38 @@ class DashboardWipPoint(SQLModel):
class DashboardRangeSeries(SQLModel):
"""Series payload for a single range/bucket combination."""
range: Literal["24h", "7d"]
bucket: Literal["hour", "day"]
points: list[DashboardSeriesPoint]
class DashboardWipRangeSeries(SQLModel):
"""WIP series payload for a single range/bucket combination."""
range: Literal["24h", "7d"]
bucket: Literal["hour", "day"]
points: list[DashboardWipPoint]
class DashboardSeriesSet(SQLModel):
"""Primary vs comparison pair for generic series metrics."""
primary: DashboardRangeSeries
comparison: DashboardRangeSeries
class DashboardWipSeriesSet(SQLModel):
"""Primary vs comparison pair for WIP status series metrics."""
primary: DashboardWipRangeSeries
comparison: DashboardWipRangeSeries
class DashboardKpis(SQLModel):
"""Topline dashboard KPI summary values."""
active_agents: int
tasks_in_progress: int
error_rate_pct: float
@@ -48,6 +64,8 @@ class DashboardKpis(SQLModel):
class DashboardMetrics(SQLModel):
"""Complete dashboard metrics response payload."""
range: Literal["24h", "7d"]
generated_at: datetime
kpis: DashboardKpis

View File

@@ -1,12 +1,16 @@
"""Schemas for organization, membership, and invite API payloads."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from datetime import datetime # noqa: TCH003
from uuid import UUID # noqa: TCH003
from sqlmodel import Field, SQLModel
class OrganizationRead(SQLModel):
"""Organization payload returned by read endpoints."""
id: UUID
name: str
created_at: datetime
@@ -14,14 +18,20 @@ class OrganizationRead(SQLModel):
class OrganizationCreate(SQLModel):
"""Payload for creating a new organization."""
name: str
class OrganizationActiveUpdate(SQLModel):
"""Payload for switching the active organization context."""
organization_id: UUID
class OrganizationListItem(SQLModel):
"""Organization list row for current user memberships."""
id: UUID
name: str
role: str
@@ -29,6 +39,8 @@ class OrganizationListItem(SQLModel):
class OrganizationUserRead(SQLModel):
"""Embedded user fields included in organization member payloads."""
id: UUID
email: str | None = None
name: str | None = None
@@ -36,6 +48,8 @@ class OrganizationUserRead(SQLModel):
class OrganizationMemberRead(SQLModel):
"""Organization member payload including board-level access overrides."""
id: UUID
organization_id: UUID
user_id: UUID
@@ -49,16 +63,22 @@ class OrganizationMemberRead(SQLModel):
class OrganizationMemberUpdate(SQLModel):
"""Payload for partial updates to organization member role."""
role: str | None = None
class OrganizationBoardAccessSpec(SQLModel):
"""Board access specification used in member/invite mutation payloads."""
board_id: UUID
can_read: bool = True
can_write: bool = False
class OrganizationBoardAccessRead(SQLModel):
"""Board access payload returned from read endpoints."""
id: UUID
board_id: UUID
can_read: bool
@@ -68,12 +88,16 @@ class OrganizationBoardAccessRead(SQLModel):
class OrganizationMemberAccessUpdate(SQLModel):
"""Payload for replacing organization member access permissions."""
all_boards_read: bool = False
all_boards_write: bool = False
board_access: list[OrganizationBoardAccessSpec] = Field(default_factory=list)
class OrganizationInviteCreate(SQLModel):
"""Payload for creating an organization invite."""
invited_email: str
role: str = "member"
all_boards_read: bool = False
@@ -82,6 +106,8 @@ class OrganizationInviteCreate(SQLModel):
class OrganizationInviteRead(SQLModel):
"""Organization invite payload returned from read endpoints."""
id: UUID
organization_id: UUID
invited_email: str
@@ -97,4 +123,6 @@ class OrganizationInviteRead(SQLModel):
class OrganizationInviteAccept(SQLModel):
"""Payload for accepting an organization invite token."""
token: str

View File

@@ -1,3 +1,5 @@
"""Shared pagination response type aliases used by API routes."""
from __future__ import annotations
from typing import TypeVar

View File

@@ -1,9 +1,13 @@
"""Schemas for souls-directory search and markdown fetch responses."""
from __future__ import annotations
from pydantic import BaseModel
class SoulsDirectorySoulRef(BaseModel):
"""Reference metadata for a soul entry in the directory index."""
handle: str
slug: str
page_url: str
@@ -11,10 +15,14 @@ class SoulsDirectorySoulRef(BaseModel):
class SoulsDirectorySearchResponse(BaseModel):
"""Response wrapper for directory search results."""
items: list[SoulsDirectorySoulRef]
class SoulsDirectoryMarkdownResponse(BaseModel):
"""Response payload containing rendered markdown for a soul."""
handle: str
slug: str
content: str

View File

@@ -1,18 +1,23 @@
"""Schemas for task CRUD and task comment API payloads."""
from __future__ import annotations
from datetime import datetime
from typing import Any, Literal, Self
from uuid import UUID
from datetime import datetime # noqa: TCH003
from typing import Literal, Self
from uuid import UUID # noqa: TCH003
from pydantic import field_validator, model_validator
from sqlmodel import Field, SQLModel
from app.schemas.common import NonEmptyStr
from app.schemas.common import NonEmptyStr # noqa: TCH001
TaskStatus = Literal["inbox", "in_progress", "review", "done"]
STATUS_REQUIRED_ERROR = "status is required"
class TaskBase(SQLModel):
"""Shared task fields used by task create/read payloads."""
title: str
description: str | None = None
status: TaskStatus = "inbox"
@@ -23,10 +28,14 @@ class TaskBase(SQLModel):
class TaskCreate(TaskBase):
"""Payload for creating a task."""
created_by_user_id: UUID | None = None
class TaskUpdate(SQLModel):
"""Payload for partial task updates."""
title: str | None = None
description: str | None = None
status: TaskStatus | None = None
@@ -38,7 +47,8 @@ class TaskUpdate(SQLModel):
@field_validator("comment", mode="before")
@classmethod
def normalize_comment(cls, value: Any) -> Any:
def normalize_comment(cls, value: object) -> object | None:
"""Normalize blank comment strings to `None`."""
if value is None:
return None
if isinstance(value, str) and not value.strip():
@@ -47,12 +57,15 @@ class TaskUpdate(SQLModel):
@model_validator(mode="after")
def validate_status(self) -> Self:
"""Ensure explicitly supplied status is not null."""
if "status" in self.model_fields_set and self.status is None:
raise ValueError("status is required")
raise ValueError(STATUS_REQUIRED_ERROR)
return self
class TaskRead(TaskBase):
"""Task payload returned from read endpoints."""
id: UUID
board_id: UUID | None
created_by_user_id: UUID | None
@@ -64,10 +77,14 @@ class TaskRead(TaskBase):
class TaskCommentCreate(SQLModel):
"""Payload for creating a task comment."""
message: NonEmptyStr
class TaskCommentRead(SQLModel):
"""Task comment payload returned from read endpoints."""
id: UUID
message: str | None
agent_id: UUID | None

View File

@@ -1,11 +1,15 @@
"""User API schemas for create, update, and read operations."""
from __future__ import annotations
from uuid import UUID
from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel
class UserBase(SQLModel):
"""Common user profile fields shared across user payload schemas."""
clerk_user_id: str
email: str | None = None
name: str | None = None
@@ -17,10 +21,12 @@ class UserBase(SQLModel):
class UserCreate(UserBase):
pass
"""Payload used to create a user record."""
class UserUpdate(SQLModel):
"""Payload for partial user profile updates."""
name: str | None = None
preferred_name: str | None = None
pronouns: str | None = None
@@ -30,5 +36,7 @@ class UserUpdate(SQLModel):
class UserRead(UserBase):
"""Full user payload returned by API responses."""
id: UUID
is_super_admin: bool

View File

@@ -1,25 +1,31 @@
"""Composite read models assembled for board and board-group views."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from datetime import datetime # noqa: TCH003
from uuid import UUID # noqa: TCH003
from sqlmodel import Field, SQLModel
from app.schemas.agents import AgentRead
from app.schemas.approvals import ApprovalRead
from app.schemas.board_groups import BoardGroupRead
from app.schemas.board_memory import BoardMemoryRead
from app.schemas.boards import BoardRead
from app.schemas.agents import AgentRead # noqa: TCH001
from app.schemas.approvals import ApprovalRead # noqa: TCH001
from app.schemas.board_groups import BoardGroupRead # noqa: TCH001
from app.schemas.board_memory import BoardMemoryRead # noqa: TCH001
from app.schemas.boards import BoardRead # noqa: TCH001
from app.schemas.tasks import TaskRead
class TaskCardRead(TaskRead):
"""Task read model enriched with assignee and approval counters."""
assignee: str | None = None
approvals_count: int = 0
approvals_pending_count: int = 0
class BoardSnapshot(SQLModel):
"""Aggregated board payload used by board snapshot endpoints."""
board: BoardRead
tasks: list[TaskCardRead]
agents: list[AgentRead]
@@ -29,6 +35,8 @@ class BoardSnapshot(SQLModel):
class BoardGroupTaskSummary(SQLModel):
"""Task summary row used inside board-group snapshot responses."""
id: UUID
board_id: UUID
board_name: str
@@ -44,11 +52,15 @@ class BoardGroupTaskSummary(SQLModel):
class BoardGroupBoardSnapshot(SQLModel):
"""Board-level rollup embedded within a board-group snapshot."""
board: BoardRead
task_counts: dict[str, int] = Field(default_factory=dict)
tasks: list[BoardGroupTaskSummary] = Field(default_factory=list)
class BoardGroupSnapshot(SQLModel):
"""Top-level board-group snapshot response payload."""
group: BoardGroupRead | None = None
boards: list[BoardGroupBoardSnapshot] = Field(default_factory=list)

View File

@@ -0,0 +1 @@
"""Business logic services for backend domain operations."""

View File

@@ -1,8 +1,13 @@
"""Utilities for recording normalized activity events."""
from __future__ import annotations
from uuid import UUID
from typing import TYPE_CHECKING
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.activity_events import ActivityEvent
@@ -15,6 +20,7 @@ def record_activity(
agent_id: UUID | None = None,
task_id: UUID | None = None,
) -> ActivityEvent:
"""Create and attach an activity event row to the current DB session."""
event = ActivityEvent(
event_type=event_type,
message=message,

View File

@@ -1,10 +1,16 @@
"""Access control helpers for admin-only operations."""
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import HTTPException, status
from app.core.auth import AuthContext
if TYPE_CHECKING:
from app.core.auth import AuthContext
def require_admin(auth: AuthContext) -> None:
"""Raise HTTP 403 unless the authenticated actor is a user admin."""
if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

View File

@@ -1,21 +1,31 @@
"""Gateway-facing agent provisioning and cleanup helpers."""
# ruff: noqa: EM101, TRY003
from __future__ import annotations
import hashlib
import json
import re
from contextlib import suppress
from pathlib import Path
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
from app.core.config import settings
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
openclaw_call,
)
if TYPE_CHECKING:
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
DEFAULT_HEARTBEAT_CONFIG = {"every": "10m", "target": "none"}
DEFAULT_IDENTITY_PROFILE = {
@@ -35,7 +45,8 @@ EXTRA_IDENTITY_PROFILE_FIELDS = {
"verbosity": "identity_verbosity",
"output_format": "identity_output_format",
"update_cadence": "identity_update_cadence",
# Per-agent charter (optional). Used to give agents a "purpose in life" and a distinct vibe.
# Per-agent charter (optional).
# Used to give agents a "purpose in life" and a distinct vibe.
"purpose": "identity_purpose",
"personality": "identity_personality",
"custom_instructions": "identity_custom_instructions",
@@ -54,11 +65,11 @@ DEFAULT_GATEWAY_FILES = frozenset(
"BOOT.md",
"BOOTSTRAP.md",
"MEMORY.md",
}
},
)
# These files are intended to evolve within the agent workspace. Provision them if missing,
# but avoid overwriting existing content during updates.
# These files are intended to evolve within the agent workspace.
# Provision them if missing, but avoid overwriting existing content during updates.
#
# Examples:
# - SELF.md: evolving identity/preferences
@@ -68,6 +79,7 @@ PRESERVE_AGENT_EDITABLE_FILES = frozenset({"SELF.md", "USER.md", "MEMORY.md"})
HEARTBEAT_LEAD_TEMPLATE = "HEARTBEAT_LEAD.md"
HEARTBEAT_AGENT_TEMPLATE = "HEARTBEAT_AGENT.md"
_SESSION_KEY_PARTS_MIN = 2
MAIN_TEMPLATE_MAP = {
"AGENTS.md": "MAIN_AGENTS.md",
"HEARTBEAT.md": "MAIN_HEARTBEAT.md",
@@ -97,13 +109,13 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
if not value.startswith("agent:"):
return None
parts = value.split(":")
if len(parts) < 2:
if len(parts) < _SESSION_KEY_PARTS_MIN:
return None
agent_id = parts[1].strip()
return agent_id or None
def _extract_agent_id(payload: object) -> str | None:
def _extract_agent_id(payload: object) -> str | None: # noqa: C901
def _from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
@@ -137,7 +149,7 @@ def _agent_key(agent: Agent) -> str:
session_key = agent.openclaw_session_id or ""
if session_key.startswith("agent:"):
parts = session_key.split(":")
if len(parts) >= 2 and parts[1]:
if len(parts) >= _SESSION_KEY_PARTS_MIN and parts[1]:
return parts[1]
return _slugify(agent.name)
@@ -183,14 +195,14 @@ def _ensure_workspace_file(
if not workspace_path or not name:
return
# Only write to a dedicated, explicitly-configured local directory.
# Using `gateway.workspace_root` directly here is unsafe (and CodeQL correctly flags it)
# because it is a DB-backed config value.
# Using `gateway.workspace_root` directly here is unsafe.
# CodeQL correctly flags that value because it is DB-backed config.
base_root = (settings.local_agent_workspace_root or "").strip()
if not base_root:
return
base = Path(base_root).expanduser()
# Derive a stable, safe directory name from the (potentially untrusted) workspace path.
# Derive a stable, safe directory name from the untrusted workspace path.
# This prevents path traversal and avoids writing to arbitrary locations.
digest = hashlib.sha256(workspace_path.encode("utf-8")).hexdigest()[:16]
root = base / f"gateway-workspace-{digest}"
@@ -345,12 +357,14 @@ async def _supported_gateway_files(config: GatewayClientConfig) -> set[str]:
default_id = None
if isinstance(agents_payload, dict):
agents = list(agents_payload.get("agents") or [])
default_id = agents_payload.get("defaultId") or agents_payload.get("default_id")
default_id = agents_payload.get("defaultId") or agents_payload.get(
"default_id",
)
agent_id = default_id or (agents[0].get("id") if agents else None)
if not agent_id:
return set(DEFAULT_GATEWAY_FILES)
files_payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config
"agents.files.list", {"agentId": agent_id}, config=config,
)
if isinstance(files_payload, dict):
files = files_payload.get("files") or []
@@ -374,10 +388,12 @@ async def _reset_session(session_key: str, config: GatewayClientConfig) -> None:
async def _gateway_agent_files_index(
agent_id: str, config: GatewayClientConfig
agent_id: str, config: GatewayClientConfig,
) -> dict[str, dict[str, Any]]:
try:
payload = await openclaw_call("agents.files.list", {"agentId": agent_id}, config=config)
payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config,
)
if isinstance(payload, dict):
files = payload.get("files") or []
index: dict[str, dict[str, Any]] = {}
@@ -420,21 +436,25 @@ def _render_agent_files(
)
heartbeat_path = _templates_root() / heartbeat_template
if heartbeat_path.exists():
rendered[name] = env.get_template(heartbeat_template).render(**context).strip()
rendered[name] = (
env.get_template(heartbeat_template).render(**context).strip()
)
continue
override = overrides.get(name)
if override:
rendered[name] = env.from_string(override).render(**context).strip()
continue
template_name = (
template_overrides[name] if template_overrides and name in template_overrides else name
template_overrides[name]
if template_overrides and name in template_overrides
else name
)
path = _templates_root() / template_name
if path.exists():
rendered[name] = env.get_template(template_name).render(**context).strip()
continue
if name == "MEMORY.md":
# Back-compat fallback for existing gateways that don't ship a MEMORY.md template.
# Back-compat fallback for gateways that do not ship MEMORY.md.
rendered[name] = "# MEMORY\n\nBootstrap pending.\n"
continue
rendered[name] = ""
@@ -487,7 +507,9 @@ async def _patch_gateway_agent_list(
else:
new_list.append(entry)
if not updated:
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
new_list.append(
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
)
patch = {"agents": {"list": new_list}}
params = {"raw": json.dumps(patch)}
@@ -496,7 +518,7 @@ async def _patch_gateway_agent_list(
await openclaw_call("config.patch", params, config=config)
async def patch_gateway_agent_heartbeats(
async def patch_gateway_agent_heartbeats( # noqa: C901
gateway: Gateway,
*,
entries: list[tuple[str, str, dict[str, Any]]],
@@ -521,7 +543,8 @@ async def patch_gateway_agent_heartbeats(
raise OpenClawGatewayError("config agents.list is not a list")
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries
agent_id: (workspace_path, heartbeat)
for agent_id, workspace_path, heartbeat in entries
}
updated_ids: set[str] = set()
@@ -544,7 +567,9 @@ async def patch_gateway_agent_heartbeats(
for agent_id, (workspace_path, heartbeat) in entry_by_id.items():
if agent_id in updated_ids:
continue
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
new_list.append(
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
)
patch = {"agents": {"list": new_list}}
params = {"raw": json.dumps(patch)}
@@ -585,7 +610,9 @@ async def _remove_gateway_agent_list(
raise OpenClawGatewayError("config agents.list is not a list")
new_list = [
entry for entry in lst if not (isinstance(entry, dict) and entry.get("id") == agent_id)
entry
for entry in lst
if not (isinstance(entry, dict) and entry.get("id") == agent_id)
]
if len(new_list) == len(lst):
return
@@ -616,7 +643,7 @@ async def _get_gateway_agent_entry(
return None
async def provision_agent(
async def provision_agent( # noqa: C901, PLR0912, PLR0913
agent: Agent,
board: Board,
gateway: Gateway,
@@ -627,6 +654,7 @@ async def provision_agent(
force_bootstrap: bool = False,
reset_session: bool = False,
) -> None:
"""Provision or update a regular board agent workspace."""
if not gateway.url:
return
if not gateway.workspace_root:
@@ -665,11 +693,9 @@ async def provision_agent(
content = rendered.get(name)
if not content:
continue
try:
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
except OSError:
with suppress(OSError):
# Local workspace may not be writable/available; fall back to gateway API.
pass
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
for name, content in rendered.items():
if content == "":
continue
@@ -694,7 +720,7 @@ async def provision_agent(
await _reset_session(session_key, client_config)
async def provision_main_agent(
async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
agent: Agent,
gateway: Gateway,
auth_token: str,
@@ -704,12 +730,15 @@ async def provision_main_agent(
force_bootstrap: bool = False,
reset_session: bool = False,
) -> None:
"""Provision or update the gateway main agent workspace."""
if not gateway.url:
return
if not gateway.main_session_key:
raise ValueError("gateway main_session_key is required")
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
await ensure_session(
gateway.main_session_key, config=client_config, label="Main Agent",
)
agent_id = await _gateway_default_agent_id(
client_config,
@@ -763,6 +792,7 @@ async def cleanup_agent(
agent: Agent,
gateway: Gateway,
) -> str | None:
"""Remove an agent from gateway config and delete its session."""
if not gateway.url:
return None
if not gateway.workspace_root:

View File

@@ -1,30 +1,41 @@
"""Helpers for ensuring each board has a provisioned lead agent."""
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING, Any
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.time import utcnow
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
def lead_session_key(board: Board) -> str:
"""Return the deterministic main session key for a board lead agent."""
return f"agent:lead-{board.id}:main"
def lead_agent_name(_: Board) -> str:
"""Return the default display name for board lead agents."""
return "Lead Agent"
async def ensure_board_lead_agent(
async def ensure_board_lead_agent( # noqa: PLR0913
session: AsyncSession,
*,
board: Board,
@@ -35,11 +46,12 @@ async def ensure_board_lead_agent(
identity_profile: dict[str, str] | None = None,
action: str = "provision",
) -> tuple[Agent, bool]:
"""Ensure a board has a lead agent; return `(agent, created)`."""
existing = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
.where(col(Agent.is_board_lead).is_(True)),
)
).first()
if existing:
@@ -66,7 +78,11 @@ async def ensure_board_lead_agent(
}
if identity_profile:
merged_identity_profile.update(
{key: value.strip() for key, value in identity_profile.items() if value.strip()}
{
key: value.strip()
for key, value in identity_profile.items()
if value.strip()
},
)
agent = Agent(
@@ -89,11 +105,16 @@ async def ensure_board_lead_agent(
try:
await provision_agent(agent, board, gateway, raw_token, user, action=action)
if agent.openclaw_session_id:
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name)
await ensure_session(
agent.openclaw_session_id,
config=config,
label=agent.name,
)
await send_message(
(
f"Hello {agent.name}. Your workspace has been provisioned.\n\n"
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run it once "
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run "
"it once "
"then delete it. Begin heartbeats after startup."
),
session_key=agent.openclaw_session_id,

View File

@@ -1,17 +1,17 @@
"""Helpers for assembling denormalized board snapshot response payloads."""
from __future__ import annotations
from datetime import timedelta
from uuid import UUID
from typing import TYPE_CHECKING
from sqlalchemy import case, func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow
from app.models.agents import Agent
from app.models.approvals import Approval
from app.models.board_memory import BoardMemory
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.tasks import Task
from app.schemas.agents import AgentRead
@@ -25,6 +25,13 @@ from app.services.task_dependencies import (
dependency_status_by_id,
)
if TYPE_CHECKING:
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
OFFLINE_AFTER = timedelta(minutes=10)
@@ -48,9 +55,15 @@ def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
model = AgentRead.model_validate(agent, from_attributes=True)
computed_status = _computed_agent_status(agent)
is_gateway_main = bool(
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys
agent.openclaw_session_id
and agent.openclaw_session_id in main_session_keys,
)
return model.model_copy(
update={
"status": computed_status,
"is_gateway_main": is_gateway_main,
},
)
return model.model_copy(update={"status": computed_status, "is_gateway_main": is_gateway_main})
def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead:
@@ -72,7 +85,9 @@ def _task_to_card(
card = TaskCardRead.model_validate(task, from_attributes=True)
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
assignee = (
agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id is not None else None
agent_name_by_id.get(task.assigned_agent_id)
if task.assigned_agent_id
else None
)
depends_on_task_ids = deps_by_task_id.get(task.id, [])
blocked_by_task_ids = blocked_by_dependency_ids(
@@ -89,21 +104,26 @@ def _task_to_card(
"depends_on_task_ids": depends_on_task_ids,
"blocked_by_task_ids": blocked_by_task_ids,
"is_blocked": bool(blocked_by_task_ids),
}
},
)
async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnapshot:
"""Build a board snapshot with tasks, agents, approvals, and chat history."""
board_read = BoardRead.model_validate(board, from_attributes=True)
tasks = list(
await Task.objects.filter_by(board_id=board.id)
.order_by(col(Task.created_at).desc())
.all(session)
.all(session),
)
task_ids = [task.id for task in tasks]
deps_by_task_id = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids)
deps_by_task_id = await dependency_ids_by_task_id(
session,
board_id=board.id,
task_ids=task_ids,
)
all_dependency_ids: list[UUID] = []
for values in deps_by_task_id.values():
all_dependency_ids.extend(values)
@@ -127,9 +147,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
await session.exec(
select(func.count(col(Approval.id)))
.where(col(Approval.board_id) == board.id)
.where(col(Approval.status) == "pending")
)
).one()
.where(col(Approval.status) == "pending"),
),
).one(),
)
approvals = (
@@ -146,12 +166,14 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
select(
col(Approval.task_id),
func.count(col(Approval.id)).label("total"),
func.sum(case((col(Approval.status) == "pending", 1), else_=0)).label("pending"),
func.sum(
case((col(Approval.status) == "pending", 1), else_=0),
).label("pending"),
)
.where(col(Approval.board_id) == board.id)
.where(col(Approval.task_id).is_not(None))
.group_by(col(Approval.task_id))
)
.group_by(col(Approval.task_id)),
),
)
for task_id, total, pending in rows:
if task_id is None:

View File

@@ -1,26 +1,33 @@
"""Policy helpers for lead-agent approval and planning decisions."""
from __future__ import annotations
import hashlib
from typing import Mapping
CONFIDENCE_THRESHOLD = 80
MIN_PLANNING_SIGNALS = 2
def compute_confidence(rubric_scores: Mapping[str, int]) -> int:
"""Compute aggregate confidence from rubric score components."""
return int(sum(rubric_scores.values()))
def approval_required(*, confidence: int, is_external: bool, is_risky: bool) -> bool:
"""Return whether an action must go through explicit approval."""
return is_external or is_risky or confidence < CONFIDENCE_THRESHOLD
def infer_planning(signals: Mapping[str, bool]) -> bool:
"""Infer planning intent from boolean heuristic signals."""
# Require at least two planning signals to avoid spam on general boards.
truthy = [key for key, value in signals.items() if value]
return len(truthy) >= 2
return len(truthy) >= MIN_PLANNING_SIGNALS
def task_fingerprint(title: str, description: str | None, board_id: str) -> str:
"""Build a stable hash key for deduplicating similar board tasks."""
normalized_title = title.strip().lower()
normalized_desc = (description or "").strip().lower()
seed = f"{board_id}::{normalized_title}::{normalized_desc}"

View File

@@ -1,18 +1,24 @@
"""Helpers for extracting and matching `@mention` tokens in text."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from app.models.agents import Agent
if TYPE_CHECKING:
from app.models.agents import Agent
# Mention tokens are single, space-free words (e.g. "@alex", "@lead").
MENTION_PATTERN = re.compile(r"@([A-Za-z][\w-]{0,31})")
def extract_mentions(message: str) -> set[str]:
"""Extract normalized mention handles from a message body."""
return {match.group(1).lower() for match in MENTION_PATTERN.finditer(message)}
def matches_agent_mention(agent: Agent, mentions: set[str]) -> bool:
"""Return whether a mention set targets the provided agent."""
if not mentions:
return False

View File

@@ -1,14 +1,14 @@
"""Organization membership and board-access service helpers."""
# ruff: noqa: D101, D103
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable
from uuid import UUID
from typing import TYPE_CHECKING, Iterable
from fastapi import HTTPException, status
from sqlalchemy import func, or_
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow
from app.db import crud
@@ -19,7 +19,17 @@ from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization
from app.models.users import User
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
if TYPE_CHECKING:
from uuid import UUID
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession
from app.schemas.organizations import (
OrganizationBoardAccessSpec,
OrganizationMemberAccessUpdate,
)
DEFAULT_ORG_NAME = "Personal"
ADMIN_ROLES = {"owner", "admin"}
@@ -63,7 +73,9 @@ async def get_member(
).first(session)
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
async def get_first_membership(
session: AsyncSession, user_id: UUID,
) -> OrganizationMember | None:
return (
await OrganizationMember.objects.filter_by(user_id=user_id)
.order_by(col(OrganizationMember.created_at).asc())
@@ -79,7 +91,9 @@ async def set_active_organization(
) -> OrganizationMember:
member = await get_member(session, user_id=user.id, organization_id=organization_id)
if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
)
if user.active_organization_id != organization_id:
user.active_organization_id = organization_id
session.add(user)
@@ -154,9 +168,10 @@ async def accept_invite(
access_rows = list(
await session.exec(
select(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
)
)
col(OrganizationInviteBoardAccess.organization_invite_id)
== invite.id,
),
),
)
for row in access_rows:
session.add(
@@ -167,7 +182,7 @@ async def accept_invite(
can_write=row.can_write,
created_at=now,
updated_at=now,
)
),
)
invite.accepted_by_user_id = user.id
@@ -182,7 +197,9 @@ async def accept_invite(
return member
async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember:
async def ensure_member_for_user(
session: AsyncSession, user: User,
) -> OrganizationMember:
existing = await get_active_membership(session, user)
if existing is not None:
return existing
@@ -196,7 +213,9 @@ async def ensure_member_for_user(session: AsyncSession, user: User) -> Organizat
now = utcnow()
member_count = (
await session.exec(
select(func.count()).where(col(OrganizationMember.organization_id) == org.id)
select(func.count()).where(
col(OrganizationMember.organization_id) == org.id,
),
)
).one()
is_first = int(member_count or 0) == 0
@@ -257,30 +276,40 @@ async def require_board_access(
board: Board,
write: bool,
) -> OrganizationMember:
member = await get_member(session, user_id=user.id, organization_id=board.organization_id)
member = await get_member(
session, user_id=user.id, organization_id=board.organization_id,
)
if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
)
if not await has_board_access(session, member=member, board=board, write=write):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied",
)
return member
def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]:
def board_access_filter(
member: OrganizationMember, *, write: bool,
) -> ColumnElement[bool]:
if write and member_all_boards_write(member):
return col(Board.organization_id) == member.organization_id
if not write and member_all_boards_read(member):
return col(Board.organization_id) == member.organization_id
access_stmt = select(OrganizationBoardAccess.board_id).where(
col(OrganizationBoardAccess.organization_member_id) == member.id
col(OrganizationBoardAccess.organization_member_id) == member.id,
)
if write:
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
access_stmt = access_stmt.where(
col(OrganizationBoardAccess.can_write).is_(True),
)
else:
access_stmt = access_stmt.where(
or_(
col(OrganizationBoardAccess.can_read).is_(True),
col(OrganizationBoardAccess.can_write).is_(True),
)
),
)
return col(Board.id).in_(access_stmt)
@@ -295,21 +324,25 @@ async def list_accessible_board_ids(
not write and member_all_boards_read(member)
):
ids = await session.exec(
select(Board.id).where(col(Board.organization_id) == member.organization_id)
select(Board.id).where(
col(Board.organization_id) == member.organization_id,
),
)
return list(ids)
access_stmt = select(OrganizationBoardAccess.board_id).where(
col(OrganizationBoardAccess.organization_member_id) == member.id
col(OrganizationBoardAccess.organization_member_id) == member.id,
)
if write:
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
access_stmt = access_stmt.where(
col(OrganizationBoardAccess.can_write).is_(True),
)
else:
access_stmt = access_stmt.where(
or_(
col(OrganizationBoardAccess.can_read).is_(True),
col(OrganizationBoardAccess.can_write).is_(True),
)
),
)
board_ids = await session.exec(access_stmt)
return list(board_ids)
@@ -337,18 +370,17 @@ async def apply_member_access_update(
if update.all_boards_read or update.all_boards_write:
return
rows: list[OrganizationBoardAccess] = []
for entry in update.board_access:
rows.append(
OrganizationBoardAccess(
organization_member_id=member.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
rows = [
OrganizationBoardAccess(
organization_member_id=member.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
for entry in update.board_access
]
session.add_all(rows)
@@ -367,18 +399,17 @@ async def apply_invite_board_access(
if invite.all_boards_read or invite.all_boards_write:
return
now = utcnow()
rows: list[OrganizationInviteBoardAccess] = []
for entry in entries:
rows.append(
OrganizationInviteBoardAccess(
organization_invite_id=invite.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
rows = [
OrganizationInviteBoardAccess(
organization_invite_id=invite.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
for entry in entries
]
session.add_all(rows)
@@ -423,9 +454,9 @@ async def apply_invite_to_member(
access_rows = list(
await session.exec(
select(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
)
)
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
),
),
)
for row in access_rows:
existing = (
@@ -433,7 +464,7 @@ async def apply_invite_to_member(
select(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id) == member.id,
col(OrganizationBoardAccess.board_id) == row.board_id,
)
),
)
).first()
can_write = bool(row.can_write)
@@ -447,7 +478,7 @@ async def apply_invite_to_member(
can_write=can_write,
created_at=now,
updated_at=now,
)
),
)
else:
existing.can_read = bool(existing.can_read or can_read)

View File

@@ -1,3 +1,5 @@
"""Service helpers for querying and caching souls.directory content."""
from __future__ import annotations
import time
@@ -11,33 +13,41 @@ SOULS_DIRECTORY_BASE_URL: Final[str] = "https://souls.directory"
SOULS_DIRECTORY_SITEMAP_URL: Final[str] = f"{SOULS_DIRECTORY_BASE_URL}/sitemap.xml"
_SITEMAP_TTL_SECONDS: Final[int] = 60 * 60
_SOUL_URL_MIN_PARTS: Final[int] = 6
@dataclass(frozen=True, slots=True)
class SoulRef:
"""Handle/slug reference pair for a soul entry."""
handle: str
slug: str
@property
def page_url(self) -> str:
"""Return the canonical page URL for this soul."""
return f"{SOULS_DIRECTORY_BASE_URL}/souls/{self.handle}/{self.slug}"
@property
def raw_md_url(self) -> str:
"""Return the raw markdown URL for this soul."""
return f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{self.handle}/{self.slug}.md"
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
"""Parse sitemap XML and extract valid souls.directory handle/slug refs."""
try:
root = ET.fromstring(sitemap_xml)
# Souls sitemap is fetched from a known trusted host in this service flow.
root = ET.fromstring(sitemap_xml) # noqa: S314
except ET.ParseError:
return []
# Handle both namespaced and non-namespaced sitemap XML.
urls: list[str] = []
for loc in root.iter():
if loc.tag.endswith("loc") and loc.text:
urls.append(loc.text.strip())
urls = [
loc.text.strip()
for loc in root.iter()
if loc.tag.endswith("loc") and loc.text
]
refs: list[SoulRef] = []
for url in urls:
@@ -45,7 +55,7 @@ def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
continue
# Expected: https://souls.directory/souls/{handle}/{slug}
parts = url.split("/")
if len(parts) < 6:
if len(parts) < _SOUL_URL_MIN_PARTS:
continue
handle = parts[4].strip()
slug = parts[5].strip()
@@ -61,7 +71,11 @@ _sitemap_cache: dict[str, object] = {
}
async def list_souls_directory_refs(*, client: httpx.AsyncClient | None = None) -> list[SoulRef]:
async def list_souls_directory_refs(
*,
client: httpx.AsyncClient | None = None,
) -> list[SoulRef]:
"""Return cached sitemap-derived soul refs, refreshing when TTL expires."""
now = time.time()
loaded_raw = _sitemap_cache.get("loaded_at")
loaded_at = loaded_raw if isinstance(loaded_raw, (int, float)) else 0.0
@@ -93,11 +107,15 @@ async def fetch_soul_markdown(
slug: str,
client: httpx.AsyncClient | None = None,
) -> str:
"""Fetch raw markdown content for a specific handle/slug pair."""
normalized_handle = handle.strip().strip("/")
normalized_slug = slug.strip().strip("/")
if normalized_slug.endswith(".md"):
normalized_slug = normalized_slug[: -len(".md")]
url = f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{normalized_handle}/{normalized_slug}.md"
url = (
f"{SOULS_DIRECTORY_BASE_URL}/api/souls/"
f"{normalized_handle}/{normalized_slug}.md"
)
owns_client = client is None
if client is None:
@@ -115,6 +133,7 @@ async def fetch_soul_markdown(
def search_souls(refs: list[SoulRef], *, query: str, limit: int = 20) -> list[SoulRef]:
"""Search refs by case-insensitive handle/slug substring with a hard limit."""
q = query.strip().lower()
if not q:
return refs[: max(0, min(limit, len(refs)))]

View File

@@ -0,0 +1 @@
"""Background worker tasks and queue processing utilities."""

View File

@@ -1,3 +1,5 @@
"""RQ queue and Redis connection helpers for background workers."""
from __future__ import annotations
from redis import Redis
@@ -7,8 +9,10 @@ from app.core.config import settings
def get_redis() -> Redis:
"""Create a Redis client from configured settings."""
return Redis.from_url(settings.redis_url)
def get_queue(name: str) -> Queue:
"""Return an RQ queue bound to the configured Redis connection."""
return Queue(name, connection=get_redis())