feat: add validation for minimum length on various fields and update type definitions

This commit is contained in:
Abhimanyu Saharan
2026-02-06 16:12:04 +05:30
parent ca614328ac
commit d86fe0a7a6
157 changed files with 12340 additions and 2977 deletions

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
from fastapi import APIRouter, Depends, Query
from sqlalchemy import desc
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import ActorContext, require_admin_or_agent
from app.db.session import get_session
@@ -13,14 +14,14 @@ router = APIRouter(prefix="/activity", tags=["activity"])
@router.get("", response_model=list[ActivityEventRead])
def list_activity(
async def list_activity(
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> list[ActivityEvent]:
statement = select(ActivityEvent)
if actor.actor_type == "agent" and actor.agent:
statement = statement.where(ActivityEvent.agent_id == actor.agent.id)
statement = statement.order_by(desc(col(ActivityEvent.created_at))).offset(offset).limit(limit)
return list(session.exec(statement))
return list(await session.exec(statement))

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
import asyncio
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api import agents as agents_api
from app.api import approvals as approvals_api
@@ -16,15 +16,20 @@ from app.core.agent_auth import AgentAuthContext, get_agent_auth_context
from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approvals import Approval
from app.models.board_memory import BoardMemory
from app.models.board_onboarding import BoardOnboardingSession
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.tasks import Task
from app.schemas.agents import AgentCreate, AgentHeartbeat, AgentHeartbeatCreate, AgentNudge, AgentRead
from app.schemas.approvals import ApprovalCreate, ApprovalRead
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus
from app.schemas.board_memory import BoardMemoryCreate, BoardMemoryRead
from app.schemas.board_onboarding import BoardOnboardingRead
from app.schemas.board_onboarding import BoardOnboardingAgentUpdate, BoardOnboardingRead
from app.schemas.boards import BoardRead
from app.schemas.common import OkResponse
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
from app.services.activity_log import record_activity
@@ -40,24 +45,24 @@ def _guard_board_access(agent_ctx: AgentAuthContext, board: Board) -> None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
def _gateway_config(session: Session, board: Board) -> GatewayClientConfig:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig:
if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
gateway = session.get(Gateway, board.gateway_id)
gateway = await session.get(Gateway, board.gateway_id)
if gateway is None or not gateway.url:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
return GatewayClientConfig(url=gateway.url, token=gateway.token)
@router.get("/boards", response_model=list[BoardRead])
def list_boards(
session: Session = Depends(get_session),
async def list_boards(
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> list[Board]:
if agent_ctx.agent.board_id:
board = session.get(Board, agent_ctx.agent.board_id)
board = await session.get(Board, agent_ctx.agent.board_id)
return [board] if board else []
return list(session.exec(select(Board)))
return list(await session.exec(select(Board)))
@router.get("/boards/{board_id}", response_model=BoardRead)
@@ -70,10 +75,10 @@ def get_board(
@router.get("/agents", response_model=list[AgentRead])
def list_agents(
async def list_agents(
board_id: UUID | None = Query(default=None),
limit: int | None = Query(default=None, ge=1, le=200),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> list[AgentRead]:
statement = select(Agent)
@@ -85,8 +90,8 @@ def list_agents(
statement = statement.where(Agent.board_id == board_id)
if limit is not None:
statement = statement.limit(limit)
agents = list(session.exec(statement))
main_session_keys = agents_api._get_gateway_main_session_keys(session)
agents = list(await session.exec(statement))
main_session_keys = await agents_api._get_gateway_main_session_keys(session)
return [
agents_api._to_agent_read(agents_api._with_computed_status(agent), main_session_keys)
for agent in agents
@@ -94,17 +99,17 @@ def list_agents(
@router.get("/boards/{board_id}/tasks", response_model=list[TaskRead])
def list_tasks(
async def list_tasks(
status_filter: str | None = Query(default=None, alias="status"),
assigned_agent_id: UUID | None = None,
unassigned: bool | None = None,
limit: int | None = Query(default=None, ge=1, le=200),
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> list[TaskRead]:
) -> list[Task]:
_guard_board_access(agent_ctx, board)
return tasks_api.list_tasks(
return await tasks_api.list_tasks(
status_filter=status_filter,
assigned_agent_id=assigned_agent_id,
unassigned=unassigned,
@@ -116,22 +121,21 @@ def list_tasks(
@router.post("/boards/{board_id}/tasks", response_model=TaskRead)
def create_task(
async def create_task(
payload: TaskCreate,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> TaskRead:
) -> Task:
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
tasks_api.validate_task_status(payload.status)
task = Task.model_validate(payload)
task.board_id = board.id
task.auto_created = True
task.auto_reason = f"lead_agent:{agent_ctx.agent.id}"
if task.assigned_agent_id:
agent = session.get(Agent, task.assigned_agent_id)
agent = await session.get(Agent, task.assigned_agent_id)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.is_board_lead:
@@ -142,8 +146,8 @@ def create_task(
if agent.board_id and agent.board_id != board.id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
session.add(task)
session.commit()
session.refresh(task)
await session.commit()
await session.refresh(task)
record_activity(
session,
event_type="task.created",
@@ -151,11 +155,11 @@ def create_task(
message=f"Task created by lead: {task.title}.",
agent_id=agent_ctx.agent.id,
)
session.commit()
await session.commit()
if task.assigned_agent_id:
assigned_agent = session.get(Agent, task.assigned_agent_id)
assigned_agent = await session.get(Agent, task.assigned_agent_id)
if assigned_agent:
tasks_api._notify_agent_on_task_assign(
await tasks_api._notify_agent_on_task_assign(
session=session,
board=board,
task=task,
@@ -165,15 +169,15 @@ def create_task(
@router.patch("/boards/{board_id}/tasks/{task_id}", response_model=TaskRead)
def update_task(
async def update_task(
payload: TaskUpdate,
task=Depends(get_task_or_404),
session: Session = Depends(get_session),
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> TaskRead:
) -> Task:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return tasks_api.update_task(
return await tasks_api.update_task(
payload=payload,
task=task,
session=session,
@@ -182,14 +186,14 @@ def update_task(
@router.get("/boards/{board_id}/tasks/{task_id}/comments", response_model=list[TaskCommentRead])
def list_task_comments(
task=Depends(get_task_or_404),
session: Session = Depends(get_session),
async def list_task_comments(
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> list[TaskCommentRead]:
) -> list[ActivityEvent]:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return tasks_api.list_task_comments(
return await tasks_api.list_task_comments(
task=task,
session=session,
actor=_actor(agent_ctx),
@@ -197,15 +201,15 @@ def list_task_comments(
@router.post("/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead)
def create_task_comment(
async def create_task_comment(
payload: TaskCommentCreate,
task=Depends(get_task_or_404),
session: Session = Depends(get_session),
task: Task = Depends(get_task_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> TaskCommentRead:
) -> ActivityEvent:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return tasks_api.create_task_comment(
return await tasks_api.create_task_comment(
payload=payload,
task=task,
session=session,
@@ -214,15 +218,15 @@ def create_task_comment(
@router.get("/boards/{board_id}/memory", response_model=list[BoardMemoryRead])
def list_board_memory(
async def list_board_memory(
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> list[BoardMemoryRead]:
) -> list[BoardMemory]:
_guard_board_access(agent_ctx, board)
return board_memory_api.list_board_memory(
return await board_memory_api.list_board_memory(
limit=limit,
offset=offset,
board=board,
@@ -232,14 +236,14 @@ def list_board_memory(
@router.post("/boards/{board_id}/memory", response_model=BoardMemoryRead)
def create_board_memory(
async def create_board_memory(
payload: BoardMemoryCreate,
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> BoardMemoryRead:
) -> BoardMemory:
_guard_board_access(agent_ctx, board)
return board_memory_api.create_board_memory(
return await board_memory_api.create_board_memory(
payload=payload,
board=board,
session=session,
@@ -248,14 +252,14 @@ def create_board_memory(
@router.get("/boards/{board_id}/approvals", response_model=list[ApprovalRead])
def list_approvals(
status_filter: str | None = Query(default=None, alias="status"),
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
async def list_approvals(
status_filter: ApprovalStatus | None = Query(default=None, alias="status"),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> list[ApprovalRead]:
) -> list[Approval]:
_guard_board_access(agent_ctx, board)
return approvals_api.list_approvals(
return await approvals_api.list_approvals(
status_filter=status_filter,
board=board,
session=session,
@@ -264,14 +268,14 @@ def list_approvals(
@router.post("/boards/{board_id}/approvals", response_model=ApprovalRead)
def create_approval(
async def create_approval(
payload: ApprovalCreate,
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> ApprovalRead:
) -> Approval:
_guard_board_access(agent_ctx, board)
return approvals_api.create_approval(
return await approvals_api.create_approval(
payload=payload,
board=board,
session=session,
@@ -280,14 +284,14 @@ def create_approval(
@router.post("/boards/{board_id}/onboarding", response_model=BoardOnboardingRead)
def update_onboarding(
payload: dict[str, object],
async def update_onboarding(
payload: BoardOnboardingAgentUpdate,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> BoardOnboardingRead:
) -> BoardOnboardingSession:
_guard_board_access(agent_ctx, board)
return onboarding_api.agent_onboarding_update(
return await onboarding_api.agent_onboarding_update(
payload=payload,
board=board,
session=session,
@@ -298,7 +302,7 @@ def update_onboarding(
@router.post("/agents", response_model=AgentRead)
async def create_agent(
payload: AgentCreate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> AgentRead:
if not agent_ctx.agent.is_board_lead:
@@ -313,18 +317,18 @@ async def create_agent(
)
@router.post("/boards/{board_id}/agents/{agent_id}/nudge")
def nudge_agent(
@router.post("/boards/{board_id}/agents/{agent_id}/nudge", response_model=OkResponse)
async def nudge_agent(
payload: AgentNudge,
agent_id: str,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> dict[str, bool]:
) -> OkResponse:
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
target = session.get(Agent, agent_id)
target = await session.get(Agent, agent_id)
if target is None or (target.board_id and target.board_id != board.id):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if not target.openclaw_session_id:
@@ -332,15 +336,9 @@ def nudge_agent(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Target agent has no session key",
)
message = payload.message.strip()
if not message:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="message is required",
)
config = _gateway_config(session, board)
async def _send() -> None:
message = payload.message
config = await _gateway_config(session, board)
try:
await ensure_session(target.openclaw_session_id, config=config, label=target.name)
await send_message(
message,
@@ -348,9 +346,6 @@ def nudge_agent(
config=config,
deliver=True,
)
try:
asyncio.run(_send())
except OpenClawGatewayError as exc:
record_activity(
session,
@@ -358,7 +353,7 @@ def nudge_agent(
message=f"Nudge failed for {target.name}: {exc}",
agent_id=agent_ctx.agent.id,
)
session.commit()
await session.commit()
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
record_activity(
session,
@@ -366,18 +361,18 @@ def nudge_agent(
message=f"Nudge sent to {target.name}.",
agent_id=agent_ctx.agent.id,
)
session.commit()
return {"ok": True}
await session.commit()
return OkResponse()
@router.post("/heartbeat", response_model=AgentRead)
async def agent_heartbeat(
payload: AgentHeartbeatCreate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context),
) -> AgentRead:
# Heartbeats must apply to the authenticated agent; agent names are not unique.
return agents_api.heartbeat_agent( # type: ignore[attr-defined]
return await agents_api.heartbeat_agent(
agent_id=str(agent_ctx.agent.id),
payload=AgentHeartbeat(status=payload.status),
session=session,

View File

@@ -3,19 +3,21 @@ from __future__ import annotations
import asyncio
import json
import re
from collections.abc import AsyncIterator
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, or_, update
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from app.api.deps import ActorContext, require_admin_auth, require_admin_or_agent
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext
from app.db.session import engine, get_session
from app.core.time import utcnow
from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.activity_events import ActivityEvent
@@ -23,6 +25,7 @@ from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.tasks import Task
from app.schemas.common import OkResponse
from app.schemas.agents import (
AgentCreate,
AgentHeartbeat,
@@ -60,27 +63,6 @@ def _parse_since(value: str | None) -> datetime | None:
return parsed
def _normalize_identity_profile(
profile: dict[str, object] | None,
) -> dict[str, str] | None:
if not profile:
return None
normalized: dict[str, str] = {}
for key, raw in profile.items():
if raw is None:
continue
if isinstance(raw, list):
parts = [str(item).strip() for item in raw if str(item).strip()]
if not parts:
continue
normalized[key] = ", ".join(parts)
continue
value = str(raw).strip()
if value:
normalized[key] = value
return normalized or None
def _slugify(value: str) -> str:
slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
return slug or uuid4().hex
@@ -100,25 +82,25 @@ def _workspace_path(agent_name: str, workspace_root: str | None) -> str:
return f"{root}/workspace-{_slugify(agent_name)}"
def _require_board(session: Session, board_id: UUID | str | None) -> Board:
async def _require_board(session: AsyncSession, board_id: UUID | str | None) -> Board:
if not board_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id is required",
)
board = session.get(Board, board_id)
board = await session.get(Board, board_id)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
return board
def _require_gateway(session: Session, board: Board) -> tuple[Gateway, GatewayClientConfig]:
async def _require_gateway(session: AsyncSession, board: Board) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board gateway_id is required",
)
gateway = session.get(Gateway, board.gateway_id)
gateway = await session.get(Gateway, board.gateway_id)
if gateway is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -151,8 +133,8 @@ def _gateway_client_config(gateway: Gateway) -> GatewayClientConfig:
return GatewayClientConfig(url=gateway.url, token=gateway.token)
def _get_gateway_main_session_keys(session: Session) -> set[str]:
keys = session.exec(select(Gateway.main_session_key)).all()
async def _get_gateway_main_session_keys(session: AsyncSession) -> set[str]:
keys = (await session.exec(select(Gateway.main_session_key))).all()
return {key for key in keys if key}
@@ -165,10 +147,12 @@ def _to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
return model.model_copy(update={"is_gateway_main": _is_gateway_main(agent, main_session_keys)})
def _find_gateway_for_main_session(session: Session, session_key: str | None) -> Gateway | None:
async def _find_gateway_for_main_session(
session: AsyncSession, session_key: str | None
) -> Gateway | None:
if not session_key:
return None
return session.exec(select(Gateway).where(Gateway.main_session_key == session_key)).first()
return (await session.exec(select(Gateway).where(Gateway.main_session_key == session_key))).first()
async def _ensure_gateway_session(
@@ -184,7 +168,7 @@ async def _ensure_gateway_session(
def _with_computed_status(agent: Agent) -> Agent:
now = datetime.utcnow()
now = utcnow()
if agent.status in {"deleting", "updating"}:
return agent
if agent.last_seen_at is None:
@@ -198,24 +182,24 @@ def _serialize_agent(agent: Agent, main_session_keys: set[str]) -> dict[str, obj
return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump(mode="json")
def _fetch_agent_events(
async def _fetch_agent_events(
session: AsyncSession,
board_id: UUID | None,
since: datetime,
) -> list[Agent]:
with Session(engine) as session:
statement = select(Agent)
if board_id:
statement = statement.where(col(Agent.board_id) == board_id)
statement = statement.where(
or_(
col(Agent.updated_at) >= since,
col(Agent.last_seen_at) >= since,
)
).order_by(asc(col(Agent.updated_at)))
return list(session.exec(statement))
statement = select(Agent)
if board_id:
statement = statement.where(col(Agent.board_id) == board_id)
statement = statement.where(
or_(
col(Agent.updated_at) >= since,
col(Agent.last_seen_at) >= since,
)
).order_by(asc(col(Agent.updated_at)))
return list(await session.exec(statement))
def _record_heartbeat(session: Session, agent: Agent) -> None:
def _record_heartbeat(session: AsyncSession, agent: Agent) -> None:
record_activity(
session,
event_type="agent.heartbeat",
@@ -224,7 +208,7 @@ def _record_heartbeat(session: Session, agent: Agent) -> None:
)
def _record_instruction_failure(session: Session, agent: Agent, error: str, action: str) -> None:
def _record_instruction_failure(session: AsyncSession, agent: Agent, error: str, action: str) -> None:
action_label = action.replace("_", " ").capitalize()
record_activity(
session,
@@ -248,12 +232,12 @@ async def _send_wakeup_message(
@router.get("", response_model=list[AgentRead])
def list_agents(
session: Session = Depends(get_session),
async def list_agents(
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> list[Agent]:
agents = list(session.exec(select(Agent)))
main_session_keys = _get_gateway_main_session_keys(session)
) -> list[AgentRead]:
agents = list(await session.exec(select(Agent)))
main_session_keys = await _get_gateway_main_session_keys(session)
return [_to_agent_read(_with_computed_status(agent), main_session_keys) for agent in agents]
@@ -264,24 +248,23 @@ async def stream_agents(
since: str | None = Query(default=None),
auth: AuthContext = Depends(require_admin_auth),
) -> EventSourceResponse:
since_dt = _parse_since(since) or datetime.utcnow()
since_dt = _parse_since(since) or utcnow()
last_seen = since_dt
async def event_generator():
async def event_generator() -> AsyncIterator[dict[str, str]]:
nonlocal last_seen
while True:
if await request.is_disconnected():
break
agents = await run_in_threadpool(_fetch_agent_events, board_id, last_seen)
if agents:
with Session(engine) as session:
main_session_keys = _get_gateway_main_session_keys(session)
for agent in agents:
updated_at = agent.updated_at or agent.last_seen_at or datetime.utcnow()
if updated_at > last_seen:
last_seen = updated_at
payload = {"agent": _serialize_agent(agent, main_session_keys)}
yield {"event": "agent", "data": json.dumps(payload)}
async with async_session_maker() as session:
agents = await _fetch_agent_events(session, board_id, last_seen)
main_session_keys = await _get_gateway_main_session_keys(session) if agents else set()
for agent in agents:
updated_at = agent.updated_at or agent.last_seen_at or utcnow()
if updated_at > last_seen:
last_seen = updated_at
payload = {"agent": _serialize_agent(agent, main_session_keys)}
yield {"event": "agent", "data": json.dumps(payload)}
await asyncio.sleep(2)
return EventSourceResponse(event_generator(), ping=15)
@@ -290,9 +273,9 @@ async def stream_agents(
@router.post("", response_model=AgentRead)
async def create_agent(
payload: AgentCreate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> Agent:
) -> AgentRead:
if actor.actor_type == "agent":
if not actor.agent or not actor.agent.is_board_lead:
raise HTTPException(
@@ -311,39 +294,36 @@ async def create_agent(
)
payload = AgentCreate(**{**payload.model_dump(), "board_id": actor.agent.board_id})
board = _require_board(session, payload.board_id)
gateway, client_config = _require_gateway(session, board)
board = await _require_board(session, payload.board_id)
gateway, client_config = await _require_gateway(session, board)
data = payload.model_dump()
requested_name = (data.get("name") or "").strip()
if requested_name:
existing = session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.name).ilike(requested_name))
existing = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.name).ilike(requested_name))
)
).first()
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="An agent with this name already exists on this board.",
)
if data.get("identity_template") == "":
data["identity_template"] = None
if data.get("soul_template") == "":
data["soul_template"] = None
data["identity_profile"] = _normalize_identity_profile(data.get("identity_profile"))
agent = Agent.model_validate(data)
agent.status = "provisioning"
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
if agent.heartbeat_config is None:
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
agent.provision_requested_at = datetime.utcnow()
agent.provision_requested_at = utcnow()
agent.provision_action = "provision"
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
agent.openclaw_session_id = session_key
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
if session_error:
record_activity(
session,
@@ -358,7 +338,7 @@ async def create_agent(
message=f"Session created for {agent.name}.",
agent_id=agent.id,
)
session.commit()
await session.commit()
try:
await provision_agent(
agent,
@@ -372,9 +352,9 @@ async def create_agent(
agent.provision_confirm_token_hash = None
agent.provision_requested_at = None
agent.provision_action = None
agent.updated_at = datetime.utcnow()
agent.updated_at = utcnow()
session.add(agent)
session.commit()
await session.commit()
record_activity(
session,
event_type="agent.provision",
@@ -387,26 +367,27 @@ async def create_agent(
message=f"Wakeup message sent to {agent.name}.",
agent_id=agent.id,
)
session.commit()
await session.commit()
except OpenClawGatewayError as exc:
_record_instruction_failure(session, agent, str(exc), "provision")
session.commit()
await session.commit()
except Exception as exc: # pragma: no cover - unexpected provisioning errors
_record_instruction_failure(session, agent, str(exc), "provision")
session.commit()
return agent
await session.commit()
main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys)
@router.get("/{agent_id}", response_model=AgentRead)
def get_agent(
async def get_agent(
agent_id: str,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> Agent:
agent = session.get(Agent, agent_id)
) -> AgentRead:
agent = await session.get(Agent, agent_id)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
main_session_keys = _get_gateway_main_session_keys(session)
main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys)
@@ -415,10 +396,10 @@ async def update_agent(
agent_id: str,
payload: AgentUpdate,
force: bool = False,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> Agent:
agent = session.get(Agent, agent_id)
) -> AgentRead:
agent = await session.get(Agent, agent_id)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
updates = payload.model_dump(exclude_unset=True)
@@ -428,21 +409,15 @@ async def update_agent(
status_code=status.HTTP_403_FORBIDDEN,
detail="status is controlled by agent heartbeat",
)
if updates.get("identity_template") == "":
updates["identity_template"] = None
if updates.get("soul_template") == "":
updates["soul_template"] = None
if "identity_profile" in updates:
updates["identity_profile"] = _normalize_identity_profile(updates.get("identity_profile"))
if not updates and not force and make_main is None:
main_session_keys = _get_gateway_main_session_keys(session)
main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys)
main_gateway = _find_gateway_for_main_session(session, agent.openclaw_session_id)
main_gateway = await _find_gateway_for_main_session(session, agent.openclaw_session_id)
gateway_for_main: Gateway | None = None
if make_main is True:
board_source = updates.get("board_id") or agent.board_id
board_for_main = _require_board(session, board_source)
gateway_for_main, _ = _require_gateway(session, board_for_main)
board_for_main = await _require_board(session, board_source)
gateway_for_main, _ = await _require_gateway(session, board_for_main)
updates["board_id"] = None
agent.is_board_lead = False
agent.openclaw_session_id = gateway_for_main.main_session_key
@@ -450,18 +425,18 @@ async def update_agent(
elif make_main is False:
agent.openclaw_session_id = None
if make_main is not True and "board_id" in updates:
_require_board(session, updates["board_id"])
await _require_board(session, updates["board_id"])
for key, value in updates.items():
setattr(agent, key, value)
if make_main is None and main_gateway is not None:
agent.board_id = None
agent.is_board_lead = False
agent.updated_at = datetime.utcnow()
agent.updated_at = utcnow()
if agent.heartbeat_config is None:
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
is_main_agent = False
board: Board | None = None
gateway: Gateway | None = None
@@ -490,8 +465,8 @@ async def update_agent(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id is required for non-main agents",
)
board = _require_board(session, agent.board_id)
gateway, client_config = _require_gateway(session, board)
board = await _require_board(session, agent.board_id)
gateway, client_config = await _require_gateway(session, board)
session_key = agent.openclaw_session_id or _build_session_key(agent.name)
try:
if client_config is None:
@@ -503,19 +478,19 @@ async def update_agent(
if not agent.openclaw_session_id:
agent.openclaw_session_id = session_key
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
except OpenClawGatewayError as exc:
_record_instruction_failure(session, agent, str(exc), "update")
session.commit()
await session.commit()
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.provision_requested_at = datetime.utcnow()
agent.provision_requested_at = utcnow()
agent.provision_action = "update"
agent.status = "updating"
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
try:
if gateway is None:
raise HTTPException(
@@ -533,6 +508,11 @@ async def update_agent(
reset_session=True,
)
else:
if board is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="board is required for non-main agent provisioning",
)
await provision_agent(
agent,
board,
@@ -548,9 +528,9 @@ async def update_agent(
agent.provision_requested_at = None
agent.provision_action = None
agent.status = "online"
agent.updated_at = datetime.utcnow()
agent.updated_at = utcnow()
session.add(agent)
session.commit()
await session.commit()
record_activity(
session,
event_type="agent.update.direct",
@@ -563,33 +543,33 @@ async def update_agent(
message=f"Wakeup message sent to {agent.name}.",
agent_id=agent.id,
)
session.commit()
await session.commit()
except OpenClawGatewayError as exc:
_record_instruction_failure(session, agent, str(exc), "update")
session.commit()
await session.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Gateway update failed: {exc}",
) from exc
except Exception as exc: # pragma: no cover - unexpected provisioning errors
_record_instruction_failure(session, agent, str(exc), "update")
session.commit()
await session.commit()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Unexpected error updating agent provisioning.",
) from exc
main_session_keys = _get_gateway_main_session_keys(session)
main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys)
@router.post("/{agent_id}/heartbeat", response_model=AgentRead)
def heartbeat_agent(
async def heartbeat_agent(
agent_id: str,
payload: AgentHeartbeat,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> AgentRead:
agent = session.get(Agent, agent_id)
agent = await session.get(Agent, agent_id)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id:
@@ -598,25 +578,25 @@ def heartbeat_agent(
agent.status = payload.status
elif agent.status == "provisioning":
agent.status = "online"
agent.last_seen_at = datetime.utcnow()
agent.updated_at = datetime.utcnow()
agent.last_seen_at = utcnow()
agent.updated_at = utcnow()
_record_heartbeat(session, agent)
session.add(agent)
session.commit()
session.refresh(agent)
main_session_keys = _get_gateway_main_session_keys(session)
await session.commit()
await session.refresh(agent)
main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys)
@router.post("/heartbeat", response_model=AgentRead)
async def heartbeat_or_create_agent(
payload: AgentHeartbeatCreate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> AgentRead:
# Agent tokens must heartbeat their authenticated agent record. Names are not unique.
if actor.actor_type == "agent" and actor.agent:
return heartbeat_agent(
return await heartbeat_agent(
agent_id=str(actor.agent.id),
payload=AgentHeartbeat(status=payload.status),
session=session,
@@ -626,12 +606,12 @@ async def heartbeat_or_create_agent(
statement = select(Agent).where(Agent.name == payload.name)
if payload.board_id is not None:
statement = statement.where(Agent.board_id == payload.board_id)
agent = session.exec(statement).first()
agent = (await session.exec(statement)).first()
if agent is None:
if actor.actor_type == "agent":
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
board = _require_board(session, payload.board_id)
gateway, client_config = _require_gateway(session, board)
board = await _require_board(session, payload.board_id)
gateway, client_config = await _require_gateway(session, board)
agent = Agent(
name=payload.name,
status="provisioning",
@@ -640,13 +620,13 @@ async def heartbeat_or_create_agent(
)
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.provision_requested_at = datetime.utcnow()
agent.provision_requested_at = utcnow()
agent.provision_action = "provision"
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
agent.openclaw_session_id = session_key
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
if session_error:
record_activity(
session,
@@ -661,16 +641,16 @@ async def heartbeat_or_create_agent(
message=f"Session created for {agent.name}.",
agent_id=agent.id,
)
session.commit()
await session.commit()
try:
await provision_agent(agent, board, gateway, raw_token, actor.user, action="provision")
await _send_wakeup_message(agent, client_config, verb="provisioned")
agent.provision_confirm_token_hash = None
agent.provision_requested_at = None
agent.provision_action = None
agent.updated_at = datetime.utcnow()
agent.updated_at = utcnow()
session.add(agent)
session.commit()
await session.commit()
record_activity(
session,
event_type="agent.provision",
@@ -683,13 +663,13 @@ async def heartbeat_or_create_agent(
message=f"Wakeup message sent to {agent.name}.",
agent_id=agent.id,
)
session.commit()
await session.commit()
except OpenClawGatewayError as exc:
_record_instruction_failure(session, agent, str(exc), "provision")
session.commit()
await session.commit()
except Exception as exc: # pragma: no cover - unexpected provisioning errors
_record_instruction_failure(session, agent, str(exc), "provision")
session.commit()
await session.commit()
elif actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
elif agent.agent_token_hash is None and actor.actor_type == "user":
@@ -697,22 +677,22 @@ async def heartbeat_or_create_agent(
agent.agent_token_hash = hash_agent_token(raw_token)
if agent.heartbeat_config is None:
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
agent.provision_requested_at = datetime.utcnow()
agent.provision_requested_at = utcnow()
agent.provision_action = "provision"
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
try:
board = _require_board(session, str(agent.board_id) if agent.board_id else None)
gateway, client_config = _require_gateway(session, board)
board = await _require_board(session, str(agent.board_id) if agent.board_id else None)
gateway, client_config = await _require_gateway(session, board)
await provision_agent(agent, board, gateway, raw_token, actor.user, action="provision")
await _send_wakeup_message(agent, client_config, verb="provisioned")
agent.provision_confirm_token_hash = None
agent.provision_requested_at = None
agent.provision_action = None
agent.updated_at = datetime.utcnow()
agent.updated_at = utcnow()
session.add(agent)
session.commit()
await session.commit()
record_activity(
session,
event_type="agent.provision",
@@ -725,16 +705,16 @@ async def heartbeat_or_create_agent(
message=f"Wakeup message sent to {agent.name}.",
agent_id=agent.id,
)
session.commit()
await session.commit()
except OpenClawGatewayError as exc:
_record_instruction_failure(session, agent, str(exc), "provision")
session.commit()
await session.commit()
except Exception as exc: # pragma: no cover - unexpected provisioning errors
_record_instruction_failure(session, agent, str(exc), "provision")
session.commit()
await session.commit()
elif not agent.openclaw_session_id:
board = _require_board(session, str(agent.board_id) if agent.board_id else None)
gateway, client_config = _require_gateway(session, board)
board = await _require_board(session, str(agent.board_id) if agent.board_id else None)
gateway, client_config = await _require_gateway(session, board)
session_key, session_error = await _ensure_gateway_session(agent.name, client_config)
agent.openclaw_session_id = session_key
if session_error:
@@ -751,47 +731,45 @@ async def heartbeat_or_create_agent(
message=f"Session created for {agent.name}.",
agent_id=agent.id,
)
session.commit()
await session.commit()
if payload.status:
agent.status = payload.status
elif agent.status == "provisioning":
agent.status = "online"
agent.last_seen_at = datetime.utcnow()
agent.updated_at = datetime.utcnow()
agent.last_seen_at = utcnow()
agent.updated_at = utcnow()
_record_heartbeat(session, agent)
session.add(agent)
session.commit()
session.refresh(agent)
main_session_keys = _get_gateway_main_session_keys(session)
await session.commit()
await session.refresh(agent)
main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys)
@router.delete("/{agent_id}")
def delete_agent(
@router.delete("/{agent_id}", response_model=OkResponse)
async def delete_agent(
agent_id: str,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> dict[str, bool]:
agent = session.get(Agent, agent_id)
) -> OkResponse:
agent = await session.get(Agent, agent_id)
if agent is None:
return {"ok": True}
return OkResponse()
board = _require_board(session, str(agent.board_id) if agent.board_id else None)
gateway, client_config = _require_gateway(session, board)
board = await _require_board(session, str(agent.board_id) if agent.board_id else None)
gateway, client_config = await _require_gateway(session, board)
try:
import asyncio
workspace_path = asyncio.run(cleanup_agent(agent, gateway))
workspace_path = await cleanup_agent(agent, gateway)
except OpenClawGatewayError as exc:
_record_instruction_failure(session, agent, str(exc), "delete")
session.commit()
await session.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Gateway cleanup failed: {exc}",
) from exc
except Exception as exc: # pragma: no cover - unexpected cleanup errors
_record_instruction_failure(session, agent, str(exc), "delete")
session.commit()
await session.commit()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Workspace cleanup failed: {exc}",
@@ -804,7 +782,7 @@ def delete_agent(
agent_id=None,
)
now = datetime.now()
session.execute(
await session.execute(
update(Task)
.where(col(Task.assigned_agent_id) == agent.id)
.where(col(Task.status) == "in_progress")
@@ -815,7 +793,7 @@ def delete_agent(
updated_at=now,
)
)
session.execute(
await session.execute(
update(Task)
.where(col(Task.assigned_agent_id) == agent.id)
.where(col(Task.status) != "in_progress")
@@ -824,11 +802,11 @@ def delete_agent(
updated_at=now,
)
)
session.execute(
await session.execute(
update(ActivityEvent).where(col(ActivityEvent.agent_id) == agent.id).values(agent_id=None)
)
session.delete(agent)
session.commit()
await session.delete(agent)
await session.commit()
# Always ask the main agent to confirm workspace cleanup.
try:
@@ -843,20 +821,14 @@ def delete_agent(
"1) Remove the workspace directory.\n"
"2) Reply NO_REPLY.\n"
)
async def _request_cleanup() -> None:
await ensure_session(main_session, config=client_config, label="Main Agent")
await send_message(
cleanup_message,
session_key=main_session,
config=client_config,
deliver=False,
)
import asyncio
asyncio.run(_request_cleanup())
await ensure_session(main_session, config=client_config, label="Main Agent")
await send_message(
cleanup_message,
session_key=main_session,
config=client_config,
deliver=False,
)
except Exception:
# Cleanup request is best-effort; deletion already completed.
pass
return {"ok": True}
return OkResponse()

View File

@@ -2,24 +2,26 @@ from __future__ import annotations
import asyncio
import json
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, or_
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from app.api.deps import ActorContext, get_board_or_404, require_admin_auth, require_admin_or_agent
from app.db.session import engine, get_session
from app.core.auth import AuthContext
from app.core.time import utcnow
from app.db.session import async_session_maker, get_session
from app.models.approvals import Approval
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalUpdate
from app.models.boards import Board
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"])
ALLOWED_STATUSES = {"pending", "approved", "rejected"}
def _parse_since(value: str | None) -> datetime | None:
if not value:
@@ -45,30 +47,30 @@ def _serialize_approval(approval: Approval) -> dict[str, object]:
return ApprovalRead.model_validate(approval, from_attributes=True).model_dump(mode="json")
def _fetch_approval_events(
async def _fetch_approval_events(
session: AsyncSession,
board_id: UUID,
since: datetime,
) -> list[Approval]:
with Session(engine) as session:
statement = (
select(Approval)
.where(col(Approval.board_id) == board_id)
.where(
or_(
col(Approval.created_at) >= since,
col(Approval.resolved_at) >= since,
)
statement = (
select(Approval)
.where(col(Approval.board_id) == board_id)
.where(
or_(
col(Approval.created_at) >= since,
col(Approval.resolved_at) >= since,
)
.order_by(asc(col(Approval.created_at)))
)
return list(session.exec(statement))
.order_by(asc(col(Approval.created_at)))
)
return list(await session.exec(statement))
@router.get("", response_model=list[ApprovalRead])
def list_approvals(
status_filter: str | None = Query(default=None, alias="status"),
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
async def list_approvals(
status_filter: ApprovalStatus | None = Query(default=None, alias="status"),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> list[Approval]:
if actor.actor_type == "agent" and actor.agent:
@@ -76,32 +78,31 @@ def list_approvals(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
statement = select(Approval).where(col(Approval.board_id) == board.id)
if status_filter:
if status_filter not in ALLOWED_STATUSES:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
statement = statement.where(col(Approval.status) == status_filter)
statement = statement.order_by(col(Approval.created_at).desc())
return list(session.exec(statement))
return list(await session.exec(statement))
@router.get("/stream")
async def stream_approvals(
request: Request,
board=Depends(get_board_or_404),
board: Board = Depends(get_board_or_404),
actor: ActorContext = Depends(require_admin_or_agent),
since: str | None = Query(default=None),
) -> EventSourceResponse:
if actor.actor_type == "agent" and actor.agent:
if actor.agent.board_id and actor.agent.board_id != board.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
since_dt = _parse_since(since) or datetime.utcnow()
since_dt = _parse_since(since) or utcnow()
last_seen = since_dt
async def event_generator():
async def event_generator() -> AsyncIterator[dict[str, str]]:
nonlocal last_seen
while True:
if await request.is_disconnected():
break
approvals = await run_in_threadpool(_fetch_approval_events, board.id, last_seen)
async with async_session_maker() as session:
approvals = await _fetch_approval_events(session, board.id, last_seen)
for approval in approvals:
updated_at = _approval_updated_at(approval)
if updated_at > last_seen:
@@ -114,10 +115,10 @@ async def stream_approvals(
@router.post("", response_model=ApprovalRead)
def create_approval(
async def create_approval(
payload: ApprovalCreate,
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> Approval:
if actor.actor_type == "agent" and actor.agent:
@@ -133,30 +134,28 @@ def create_approval(
status=payload.status,
)
session.add(approval)
session.commit()
session.refresh(approval)
await session.commit()
await session.refresh(approval)
return approval
@router.patch("/{approval_id}", response_model=ApprovalRead)
def update_approval(
async def update_approval(
approval_id: str,
payload: ApprovalUpdate,
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
auth=Depends(require_admin_auth),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> Approval:
approval = session.get(Approval, approval_id)
approval = await session.get(Approval, approval_id)
if approval is None or approval.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
updates = payload.model_dump(exclude_unset=True)
if "status" in updates:
if updates["status"] not in ALLOWED_STATUSES:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
approval.status = updates["status"]
if approval.status != "pending":
approval.resolved_at = datetime.utcnow()
approval.resolved_at = utcnow()
session.add(approval)
session.commit()
session.refresh(approval)
await session.commit()
await session.refresh(approval)
return approval

View File

@@ -3,20 +3,24 @@ from __future__ import annotations
import asyncio
import json
import re
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlmodel import Session, col, select
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from app.api.deps import ActorContext, get_board_or_404, require_admin_or_agent
from app.core.config import settings
from app.db.session import engine, get_session
from app.core.time import utcnow
from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.agents import Agent
from app.models.board_memory import BoardMemory
from app.models.boards import Board
from app.models.gateways import Gateway
from app.schemas.board_memory import BoardMemoryCreate, BoardMemoryRead
@@ -62,10 +66,10 @@ def _matches_mention(agent: Agent, mentions: set[str]) -> bool:
return first in mentions
def _gateway_config(session: Session, board) -> GatewayClientConfig | None:
if not board.gateway_id:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if board.gateway_id is None:
return None
gateway = session.get(Gateway, board.gateway_id)
gateway = await session.get(Gateway, board.gateway_id)
if gateway is None or not gateway.url:
return None
return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -82,36 +86,36 @@ async def _send_agent_message(
await send_message(message, session_key=session_key, config=config, deliver=False)
def _fetch_memory_events(
board_id,
async def _fetch_memory_events(
session: AsyncSession,
board_id: UUID,
since: datetime,
) -> list[BoardMemory]:
with Session(engine) as session:
statement = (
select(BoardMemory)
.where(col(BoardMemory.board_id) == board_id)
.where(col(BoardMemory.created_at) >= since)
.order_by(col(BoardMemory.created_at))
)
return list(session.exec(statement))
statement = (
select(BoardMemory)
.where(col(BoardMemory.board_id) == board_id)
.where(col(BoardMemory.created_at) >= since)
.order_by(col(BoardMemory.created_at))
)
return list(await session.exec(statement))
def _notify_chat_targets(
async def _notify_chat_targets(
*,
session: Session,
board,
session: AsyncSession,
board: Board,
memory: BoardMemory,
actor: ActorContext,
) -> None:
if not memory.content:
return
config = _gateway_config(session, board)
config = await _gateway_config(session, board)
if config is None:
return
mentions = _extract_mentions(memory.content)
statement = select(Agent).where(col(Agent.board_id) == board.id)
targets: dict[str, Agent] = {}
for agent in session.exec(statement):
for agent in await session.exec(statement):
if agent.is_board_lead:
targets[str(agent.id)] = agent
continue
@@ -145,24 +149,22 @@ def _notify_chat_targets(
'Body: {"content":"...","tags":["chat"]}'
)
try:
asyncio.run(
_send_agent_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=message,
)
await _send_agent_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=message,
)
except OpenClawGatewayError:
continue
@router.get("", response_model=list[BoardMemoryRead])
def list_board_memory(
async def list_board_memory(
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> list[BoardMemory]:
if actor.actor_type == "agent" and actor.agent:
@@ -175,28 +177,29 @@ def list_board_memory(
.offset(offset)
.limit(limit)
)
return list(session.exec(statement))
return list(await session.exec(statement))
@router.get("/stream")
async def stream_board_memory(
request: Request,
board=Depends(get_board_or_404),
board: Board = Depends(get_board_or_404),
actor: ActorContext = Depends(require_admin_or_agent),
since: str | None = Query(default=None),
) -> EventSourceResponse:
if actor.actor_type == "agent" and actor.agent:
if actor.agent.board_id and actor.agent.board_id != board.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
since_dt = _parse_since(since) or datetime.utcnow()
since_dt = _parse_since(since) or utcnow()
last_seen = since_dt
async def event_generator():
async def event_generator() -> AsyncIterator[dict[str, str]]:
nonlocal last_seen
while True:
if await request.is_disconnected():
break
memories = await run_in_threadpool(_fetch_memory_events, board.id, last_seen)
async with async_session_maker() as session:
memories = await _fetch_memory_events(session, board.id, last_seen)
for memory in memories:
if memory.created_at > last_seen:
last_seen = memory.created_at
@@ -208,10 +211,10 @@ async def stream_board_memory(
@router.post("", response_model=BoardMemoryRead)
def create_board_memory(
async def create_board_memory(
payload: BoardMemoryCreate,
board=Depends(get_board_or_404),
session: Session = Depends(get_session),
board: Board = Depends(get_board_or_404),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> BoardMemory:
if actor.actor_type == "agent" and actor.agent:
@@ -231,8 +234,8 @@ def create_board_memory(
source=source,
)
session.add(memory)
session.commit()
session.refresh(memory)
await session.commit()
await session.refresh(memory)
if is_chat:
_notify_chat_targets(session=session, board=board, memory=memory, actor=actor)
await _notify_chat_targets(session=session, board=board, memory=memory, actor=actor)
return memory

View File

@@ -1,18 +1,19 @@
from __future__ import annotations
import json
import logging
import re
from datetime import datetime
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import Session, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import ActorContext, get_board_or_404, require_admin_auth, require_admin_or_agent
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext
from app.core.config import settings
from app.core.time import utcnow
from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
@@ -22,6 +23,8 @@ from app.models.boards import Board
from app.models.gateways import Gateway
from app.schemas.board_onboarding import (
BoardOnboardingAnswer,
BoardOnboardingAgentComplete,
BoardOnboardingAgentUpdate,
BoardOnboardingConfirm,
BoardOnboardingRead,
BoardOnboardingStart,
@@ -33,10 +36,12 @@ router = APIRouter(prefix="/boards/{board_id}/onboarding", tags=["board-onboardi
logger = logging.getLogger(__name__)
def _gateway_config(session: Session, board: Board) -> tuple[Gateway, GatewayClientConfig]:
async def _gateway_config(
session: AsyncSession, board: Board
) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
gateway = session.get(Gateway, board.gateway_id)
gateway = await session.get(Gateway, board.gateway_id)
if gateway is None or not gateway.url or not gateway.main_session_key:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
return gateway, GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -56,21 +61,25 @@ def _lead_session_key(board: Board) -> str:
async def _ensure_lead_agent(
session: Session,
session: AsyncSession,
board: Board,
gateway: Gateway,
config: GatewayClientConfig,
auth: AuthContext,
) -> Agent:
existing = session.exec(
select(Agent).where(Agent.board_id == board.id).where(Agent.is_board_lead.is_(True))
existing = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
)
).first()
if existing:
if existing.name != _lead_agent_name(board):
existing.name = _lead_agent_name(board)
session.add(existing)
session.commit()
session.refresh(existing)
await session.commit()
await session.refresh(existing)
return existing
agent = Agent(
@@ -87,12 +96,12 @@ async def _ensure_lead_agent(
)
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.provision_requested_at = datetime.utcnow()
agent.provision_requested_at = utcnow()
agent.provision_action = "provision"
agent.openclaw_session_id = _lead_session_key(board)
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
try:
await provision_agent(agent, board, gateway, raw_token, auth.user, action="provision")
@@ -114,15 +123,17 @@ async def _ensure_lead_agent(
@router.get("", response_model=BoardOnboardingRead)
def get_onboarding(
async def get_onboarding(
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> BoardOnboardingSession:
onboarding = session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(BoardOnboardingSession.created_at.desc())
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -133,18 +144,20 @@ def get_onboarding(
async def start_onboarding(
payload: BoardOnboardingStart,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> BoardOnboardingSession:
onboarding = session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.where(BoardOnboardingSession.status == "active")
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.where(BoardOnboardingSession.status == "active")
)
).first()
if onboarding:
return onboarding
gateway, config = _gateway_config(session, board)
gateway, config = await _gateway_config(session, board)
session_key = gateway.main_session_key
base_url = settings.base_url or "http://localhost:8000"
prompt = (
@@ -185,11 +198,11 @@ async def start_onboarding(
board_id=board.id,
session_key=session_key,
status="active",
messages=[{"role": "user", "content": prompt, "timestamp": datetime.utcnow().isoformat()}],
messages=[{"role": "user", "content": prompt, "timestamp": utcnow().isoformat()}],
)
session.add(onboarding)
session.commit()
session.refresh(onboarding)
await session.commit()
await session.refresh(onboarding)
return onboarding
@@ -197,25 +210,27 @@ async def start_onboarding(
async def answer_onboarding(
payload: BoardOnboardingAnswer,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> BoardOnboardingSession:
onboarding = session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(BoardOnboardingSession.created_at.desc())
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
_, config = _gateway_config(session, board)
_, config = await _gateway_config(session, board)
answer_text = payload.answer
if payload.other_text:
answer_text = f"{payload.answer}: {payload.other_text}"
messages = list(onboarding.messages or [])
messages.append(
{"role": "user", "content": answer_text, "timestamp": datetime.utcnow().isoformat()}
{"role": "user", "content": answer_text, "timestamp": utcnow().isoformat()}
)
try:
@@ -227,18 +242,18 @@ async def answer_onboarding(
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
onboarding.messages = messages
onboarding.updated_at = datetime.utcnow()
onboarding.updated_at = utcnow()
session.add(onboarding)
session.commit()
session.refresh(onboarding)
await session.commit()
await session.refresh(onboarding)
return onboarding
@router.post("/agent", response_model=BoardOnboardingRead)
def agent_onboarding_update(
payload: dict[str, object],
async def agent_onboarding_update(
payload: BoardOnboardingAgentUpdate,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> BoardOnboardingSession:
if actor.actor_type != "agent" or actor.agent is None:
@@ -248,15 +263,17 @@ def agent_onboarding_update(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if board.gateway_id:
gateway = session.get(Gateway, board.gateway_id)
gateway = await session.get(Gateway, board.gateway_id)
if gateway and gateway.main_session_key and agent.openclaw_session_id:
if agent.openclaw_session_id != gateway.main_session_key:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
onboarding = session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(BoardOnboardingSession.created_at.desc())
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -264,31 +281,27 @@ def agent_onboarding_update(
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
messages = list(onboarding.messages or [])
now = datetime.utcnow().isoformat()
payload_text = json.dumps(payload)
now = utcnow().isoformat()
payload_text = payload.model_dump_json(exclude_none=True)
payload_data = payload.model_dump(mode="json", exclude_none=True)
logger.info(
"onboarding.agent.update board_id=%s agent_id=%s payload=%s",
board.id,
agent.id,
payload_text,
)
payload_status = payload.get("status")
if payload_status == "complete":
onboarding.draft_goal = payload
if isinstance(payload, BoardOnboardingAgentComplete):
onboarding.draft_goal = payload_data
onboarding.status = "completed"
messages.append({"role": "assistant", "content": payload_text, "timestamp": now})
else:
question = payload.get("question")
options = payload.get("options")
if not isinstance(question, str) or not isinstance(options, list):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
messages.append({"role": "assistant", "content": payload_text, "timestamp": now})
onboarding.messages = messages
onboarding.updated_at = datetime.utcnow()
onboarding.updated_at = utcnow()
session.add(onboarding)
session.commit()
session.refresh(onboarding)
await session.commit()
await session.refresh(onboarding)
logger.info(
"onboarding.agent.update stored board_id=%s messages_count=%s status=%s",
board.id,
@@ -302,13 +315,15 @@ def agent_onboarding_update(
async def confirm_onboarding(
payload: BoardOnboardingConfirm,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> Board:
onboarding = session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(BoardOnboardingSession.created_at.desc())
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -321,12 +336,12 @@ async def confirm_onboarding(
board.goal_source = "lead_agent_onboarding"
onboarding.status = "confirmed"
onboarding.updated_at = datetime.utcnow()
onboarding.updated_at = utcnow()
gateway, config = _gateway_config(session, board)
gateway, config = await _gateway_config(session, board)
session.add(board)
session.add(onboarding)
session.commit()
session.refresh(board)
await session.commit()
await session.refresh(board)
await _ensure_lead_agent(session, board, gateway, config, auth)
return board

View File

@@ -1,15 +1,17 @@
from __future__ import annotations
import asyncio
import re
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import ActorContext, get_board_or_404, require_admin_auth, require_admin_or_agent
from app.core.auth import AuthContext
from app.core.time import utcnow
from app.db import crud
from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import (
@@ -27,6 +29,7 @@ from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.task_fingerprints import TaskFingerprint
from app.models.tasks import Task
from app.schemas.common import OkResponse
from app.schemas.boards import BoardCreate, BoardRead, BoardUpdate
router = APIRouter(prefix="/boards", tags=["boards"])
@@ -43,12 +46,56 @@ def _build_session_key(agent_name: str) -> str:
return f"{AGENT_SESSION_PREFIX}:{_slugify(agent_name)}:main"
def _board_gateway(
session: Session, board: Board
async def _require_gateway(session: AsyncSession, gateway_id: object) -> Gateway:
gateway = await crud.get_by_id(session, Gateway, gateway_id)
if gateway is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="gateway_id is invalid",
)
return gateway
async def _require_gateway_for_create(
payload: BoardCreate,
session: AsyncSession = Depends(get_session),
) -> Gateway:
return await _require_gateway(session, payload.gateway_id)
async def _apply_board_update(
*,
payload: BoardUpdate,
session: AsyncSession,
board: Board,
) -> Board:
updates = payload.model_dump(exclude_unset=True)
if "gateway_id" in updates:
await _require_gateway(session, updates["gateway_id"])
for key, value in updates.items():
setattr(board, key, value)
if updates.get("board_type") == "goal":
# Validate only when explicitly switching to goal boards.
if not board.objective or not board.success_metrics:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Goal boards require objective and success_metrics",
)
if not board.gateway_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="gateway_id is required",
)
board.updated_at = utcnow()
return await crud.save(session, board)
async def _board_gateway(
session: AsyncSession, board: Board
) -> tuple[Gateway | None, GatewayClientConfig | None]:
if not board.gateway_id:
return None, None
config = session.get(Gateway, board.gateway_id)
config = await session.get(Gateway, board.gateway_id)
if config is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -103,36 +150,21 @@ async def _cleanup_agent_on_gateway(
@router.get("", response_model=list[BoardRead])
def list_boards(
session: Session = Depends(get_session),
async def list_boards(
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> list[Board]:
return list(session.exec(select(Board)))
return list(await session.exec(select(Board)))
@router.post("", response_model=BoardRead)
def create_board(
async def create_board(
payload: BoardCreate,
session: Session = Depends(get_session),
_gateway: Gateway = Depends(_require_gateway_for_create),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> Board:
data = payload.model_dump()
if not data.get("gateway_id"):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="gateway_id is required",
)
config = session.get(Gateway, data["gateway_id"])
if config is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="gateway_id is invalid",
)
board = Board.model_validate(data)
session.add(board)
session.commit()
session.refresh(board)
return board
return await crud.create(session, Board, **payload.model_dump())
@router.get("/{board_id}", response_model=BoardRead)
@@ -144,60 +176,29 @@ def get_board(
@router.patch("/{board_id}", response_model=BoardRead)
def update_board(
async def update_board(
payload: BoardUpdate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
board: Board = Depends(get_board_or_404),
auth: AuthContext = Depends(require_admin_auth),
) -> Board:
updates = payload.model_dump(exclude_unset=True)
if "gateway_id" in updates:
if not updates.get("gateway_id"):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="gateway_id is required",
)
config = session.get(Gateway, updates["gateway_id"])
if config is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="gateway_id is invalid",
)
for key, value in updates.items():
setattr(board, key, value)
if updates.get("board_type") == "goal":
objective = updates.get("objective") or board.objective
metrics = updates.get("success_metrics") or board.success_metrics
if not objective or not metrics:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Goal boards require objective and success_metrics",
)
if not board.gateway_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="gateway_id is required",
)
session.add(board)
session.commit()
session.refresh(board)
return board
return await _apply_board_update(payload=payload, session=session, board=board)
@router.delete("/{board_id}")
def delete_board(
session: Session = Depends(get_session),
@router.delete("/{board_id}", response_model=OkResponse)
async def delete_board(
session: AsyncSession = Depends(get_session),
board: Board = Depends(get_board_or_404),
auth: AuthContext = Depends(require_admin_auth),
) -> dict[str, bool]:
agents = list(session.exec(select(Agent).where(Agent.board_id == board.id)))
task_ids = list(session.exec(select(Task.id).where(Task.board_id == board.id)))
) -> OkResponse:
agents = list(await session.exec(select(Agent).where(Agent.board_id == board.id)))
task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id)))
config, client_config = _board_gateway(session, board)
config, client_config = await _board_gateway(session, board)
if config and client_config:
try:
for agent in agents:
asyncio.run(_cleanup_agent_on_gateway(agent, config, client_config))
await _cleanup_agent_on_gateway(agent, config, client_config)
except OpenClawGatewayError as exc:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
@@ -205,18 +206,18 @@ def delete_board(
) from exc
if task_ids:
session.execute(delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)))
session.execute(delete(TaskFingerprint).where(col(TaskFingerprint.board_id) == board.id))
await session.execute(delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)))
await session.execute(delete(TaskFingerprint).where(col(TaskFingerprint.board_id) == board.id))
if agents:
agent_ids = [agent.id for agent in agents]
session.execute(delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids)))
session.execute(delete(Agent).where(col(Agent.id).in_(agent_ids)))
session.execute(delete(Approval).where(col(Approval.board_id) == board.id))
session.execute(delete(BoardMemory).where(col(BoardMemory.board_id) == board.id))
session.execute(
await session.execute(delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids)))
await session.execute(delete(Agent).where(col(Agent.id).in_(agent_ids)))
await session.execute(delete(Approval).where(col(Approval.board_id) == board.id))
await session.execute(delete(BoardMemory).where(col(BoardMemory.board_id) == board.id))
await session.execute(
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id) == board.id)
)
session.execute(delete(Task).where(col(Task.board_id) == board.id))
session.delete(board)
session.commit()
return {"ok": True}
await session.execute(delete(Task).where(col(Task.board_id) == board.id))
await session.delete(board)
await session.commit()
return OkResponse()

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Literal
from fastapi import Depends, HTTPException, status
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_auth import AgentAuthContext, get_agent_auth_context_optional
from app.core.auth import AuthContext, get_auth_context, get_auth_context_optional
@@ -40,22 +40,22 @@ def require_admin_or_agent(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
def get_board_or_404(
async def get_board_or_404(
board_id: str,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
) -> Board:
board = session.get(Board, board_id)
board = await session.get(Board, board_id)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return board
def get_task_or_404(
async def get_task_or_404(
task_id: str,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
) -> Task:
task = session.get(Task, task_id)
task = await session.get(Task, task_id)
if task is None or task.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return task

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from sqlmodel import Session
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext, get_auth_context
from app.db.session import get_session
@@ -20,12 +20,22 @@ from app.integrations.openclaw_gateway_protocol import (
)
from app.models.boards import Board
from app.models.gateways import Gateway
from app.schemas.common import OkResponse
from app.schemas.gateway_api import (
GatewayCommandsResponse,
GatewayResolveQuery,
GatewaySessionHistoryResponse,
GatewaySessionMessageRequest,
GatewaySessionResponse,
GatewaySessionsResponse,
GatewaysStatusResponse,
)
router = APIRouter(prefix="/gateways", tags=["gateways"])
def _resolve_gateway(
session: Session,
async def _resolve_gateway(
session: AsyncSession,
board_id: str | None,
gateway_url: str | None,
gateway_token: str | None,
@@ -42,7 +52,7 @@ def _resolve_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id or gateway_url is required",
)
board = session.get(Board, board_id)
board = await session.get(Board, board_id)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if not board.gateway_id:
@@ -50,7 +60,7 @@ def _resolve_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board gateway_id is required",
)
gateway = session.get(Gateway, board.gateway_id)
gateway = await session.get(Gateway, board.gateway_id)
if gateway is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -68,10 +78,10 @@ def _resolve_gateway(
)
def _require_gateway(
session: Session, board_id: str | None
async def _require_gateway(
session: AsyncSession, board_id: str | None
) -> tuple[Board, GatewayClientConfig, str | None]:
board, config, main_session = _resolve_gateway(session, board_id, None, None, None)
board, config, main_session = await _resolve_gateway(session, board_id, None, None, None)
if board is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -80,21 +90,18 @@ def _require_gateway(
return board, config, main_session
@router.get("/status")
@router.get("/status", response_model=GatewaysStatusResponse)
async def gateways_status(
board_id: str | None = Query(default=None),
gateway_url: str | None = Query(default=None),
gateway_token: str | None = Query(default=None),
gateway_main_session_key: str | None = Query(default=None),
session: Session = Depends(get_session),
params: GatewayResolveQuery = Depends(),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, object]:
board, config, main_session = _resolve_gateway(
) -> GatewaysStatusResponse:
board, config, main_session = await _resolve_gateway(
session,
board_id,
gateway_url,
gateway_token,
gateway_main_session_key,
params.board_id,
params.gateway_url,
params.gateway_token,
params.gateway_main_session_key,
)
try:
sessions = await openclaw_call("sessions.list", config=config)
@@ -111,30 +118,26 @@ async def gateways_status(
main_session_entry = ensured.get("entry") or ensured
except OpenClawGatewayError as exc:
main_session_error = str(exc)
return {
"connected": True,
"gateway_url": config.url,
"sessions_count": len(sessions_list),
"sessions": sessions_list,
"main_session_key": main_session,
"main_session": main_session_entry,
"main_session_error": main_session_error,
}
return GatewaysStatusResponse(
connected=True,
gateway_url=config.url,
sessions_count=len(sessions_list),
sessions=sessions_list,
main_session_key=main_session,
main_session=main_session_entry,
main_session_error=main_session_error,
)
except OpenClawGatewayError as exc:
return {
"connected": False,
"gateway_url": config.url,
"error": str(exc),
}
return GatewaysStatusResponse(connected=False, gateway_url=config.url, error=str(exc))
@router.get("/sessions")
@router.get("/sessions", response_model=GatewaySessionsResponse)
async def list_gateway_sessions(
board_id: str | None = Query(default=None),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, object]:
board, config, main_session = _resolve_gateway(
) -> GatewaySessionsResponse:
board, config, main_session = await _resolve_gateway(
session,
board_id,
None,
@@ -159,21 +162,21 @@ async def list_gateway_sessions(
except OpenClawGatewayError:
main_session_entry = None
return {
"sessions": sessions_list,
"main_session_key": main_session,
"main_session": main_session_entry,
}
return GatewaySessionsResponse(
sessions=sessions_list,
main_session_key=main_session,
main_session=main_session_entry,
)
@router.get("/sessions/{session_id}")
@router.get("/sessions/{session_id}", response_model=GatewaySessionResponse)
async def get_gateway_session(
session_id: str,
board_id: str | None = Query(default=None),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, object]:
board, config, main_session = _resolve_gateway(
) -> GatewaySessionResponse:
board, config, main_session = await _resolve_gateway(
session,
board_id,
None,
@@ -208,55 +211,50 @@ async def get_gateway_session(
session_entry = None
if session_entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
return {"session": session_entry}
return GatewaySessionResponse(session=session_entry)
@router.get("/sessions/{session_id}/history")
@router.get("/sessions/{session_id}/history", response_model=GatewaySessionHistoryResponse)
async def get_session_history(
session_id: str,
board_id: str | None = Query(default=None),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, object]:
_, config, _ = _require_gateway(session, board_id)
) -> GatewaySessionHistoryResponse:
_, config, _ = await _require_gateway(session, board_id)
try:
history = await get_chat_history(session_id, config=config)
except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
if isinstance(history, dict) and isinstance(history.get("messages"), list):
return {"history": history["messages"]}
return {"history": list(history or [])}
return GatewaySessionHistoryResponse(history=history["messages"])
return GatewaySessionHistoryResponse(history=list(history or []))
@router.post("/sessions/{session_id}/message")
@router.post("/sessions/{session_id}/message", response_model=OkResponse)
async def send_gateway_session_message(
session_id: str,
payload: dict = Body(...),
payload: GatewaySessionMessageRequest,
board_id: str | None = Query(default=None),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, bool]:
content = payload.get("content")
if not content:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="content is required"
)
board, config, main_session = _require_gateway(session, board_id)
) -> OkResponse:
board, config, main_session = await _require_gateway(session, board_id)
try:
if main_session and session_id == main_session:
await ensure_session(main_session, config=config, label="Main Agent")
await send_message(content, session_key=session_id, config=config)
await send_message(payload.content, session_key=session_id, config=config)
except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
return {"ok": True}
return OkResponse()
@router.get("/commands")
@router.get("/commands", response_model=GatewayCommandsResponse)
async def gateway_commands(
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, object]:
return {
"protocol_version": PROTOCOL_VERSION,
"methods": GATEWAY_METHODS,
"events": GATEWAY_EVENTS,
}
) -> GatewayCommandsResponse:
return GatewayCommandsResponse(
protocol_version=PROTOCOL_VERSION,
methods=GATEWAY_METHODS,
events=GATEWAY_EVENTS,
)

View File

@@ -4,15 +4,18 @@ from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow
from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.agents import Agent
from app.models.gateways import Gateway
from app.schemas.common import OkResponse
from app.schemas.gateways import GatewayCreate, GatewayRead, GatewayUpdate
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent
@@ -235,21 +238,25 @@ def _main_agent_name(gateway: Gateway) -> str:
return f"{gateway.name} Main"
def _find_main_agent(
session: Session,
async def _find_main_agent(
session: AsyncSession,
gateway: Gateway,
previous_name: str | None = None,
previous_session_key: str | None = None,
) -> Agent | None:
if gateway.main_session_key:
agent = session.exec(
select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key)
agent = (
await session.exec(
select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key)
)
).first()
if agent:
return agent
if previous_session_key:
agent = session.exec(
select(Agent).where(Agent.openclaw_session_id == previous_session_key)
agent = (
await session.exec(
select(Agent).where(Agent.openclaw_session_id == previous_session_key)
)
).first()
if agent:
return agent
@@ -257,14 +264,14 @@ def _find_main_agent(
if previous_name:
names.add(f"{previous_name} Main")
for name in names:
agent = session.exec(select(Agent).where(Agent.name == name)).first()
agent = (await session.exec(select(Agent).where(Agent.name == name))).first()
if agent:
return agent
return None
async def _ensure_main_agent(
session: Session,
session: AsyncSession,
gateway: Gateway,
auth: AuthContext,
*,
@@ -274,7 +281,7 @@ async def _ensure_main_agent(
) -> Agent | None:
if not gateway.url or not gateway.main_session_key:
return None
agent = _find_main_agent(session, gateway, previous_name, previous_session_key)
agent = await _find_main_agent(session, gateway, previous_name, previous_session_key)
if agent is None:
agent = Agent(
name=_main_agent_name(gateway),
@@ -294,14 +301,14 @@ async def _ensure_main_agent(
agent.openclaw_session_id = gateway.main_session_key
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.provision_requested_at = datetime.utcnow()
agent.provision_requested_at = utcnow()
agent.provision_action = action
agent.updated_at = datetime.utcnow()
agent.updated_at = utcnow()
if agent.heartbeat_config is None:
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
session.add(agent)
session.commit()
session.refresh(agent)
await session.commit()
await session.refresh(agent)
try:
await provision_main_agent(agent, gateway, raw_token, auth.user, action=action)
await ensure_session(
@@ -356,26 +363,24 @@ async def _send_skyll_disable_message(gateway: Gateway) -> None:
@router.get("", response_model=list[GatewayRead])
def list_gateways(
session: Session = Depends(get_session),
async def list_gateways(
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> list[Gateway]:
return list(session.exec(select(Gateway)))
return list(await session.exec(select(Gateway)))
@router.post("", response_model=GatewayRead)
async def create_gateway(
payload: GatewayCreate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> Gateway:
data = payload.model_dump()
if data.get("token") == "":
data["token"] = None
gateway = Gateway.model_validate(data)
session.add(gateway)
session.commit()
session.refresh(gateway)
await session.commit()
await session.refresh(gateway)
await _ensure_main_agent(session, gateway, auth, action="provision")
if gateway.skyll_enabled:
try:
@@ -386,12 +391,12 @@ async def create_gateway(
@router.get("/{gateway_id}", response_model=GatewayRead)
def get_gateway(
async def get_gateway(
gateway_id: UUID,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> Gateway:
gateway = session.get(Gateway, gateway_id)
gateway = await session.get(Gateway, gateway_id)
if gateway is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
return gateway
@@ -401,23 +406,21 @@ def get_gateway(
async def update_gateway(
gateway_id: UUID,
payload: GatewayUpdate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> Gateway:
gateway = session.get(Gateway, gateway_id)
gateway = await session.get(Gateway, gateway_id)
if gateway is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
previous_name = gateway.name
previous_session_key = gateway.main_session_key
previous_skyll_enabled = gateway.skyll_enabled
updates = payload.model_dump(exclude_unset=True)
if updates.get("token") == "":
updates["token"] = None
for key, value in updates.items():
setattr(gateway, key, value)
session.add(gateway)
session.commit()
session.refresh(gateway)
await session.commit()
await session.refresh(gateway)
await _ensure_main_agent(
session,
gateway,
@@ -439,15 +442,15 @@ async def update_gateway(
return gateway
@router.delete("/{gateway_id}")
def delete_gateway(
@router.delete("/{gateway_id}", response_model=OkResponse)
async def delete_gateway(
gateway_id: UUID,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, bool]:
gateway = session.get(Gateway, gateway_id)
) -> OkResponse:
gateway = await session.get(Gateway, gateway_id)
if gateway is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
session.delete(gateway)
session.commit()
return {"ok": True}
await session.delete(gateway)
await session.commit()
return OkResponse()

View File

@@ -6,10 +6,12 @@ from typing import Literal
from fastapi import APIRouter, Depends, Query
from sqlalchemy import DateTime, case, cast, func
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_admin_auth
from app.core.auth import AuthContext
from app.core.time import utcnow
from app.db.session import get_session
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
@@ -40,7 +42,7 @@ class RangeSpec:
def _resolve_range(range_key: Literal["24h", "7d"]) -> RangeSpec:
now = datetime.utcnow()
now = utcnow()
if range_key == "7d":
return RangeSpec(
key="7d",
@@ -111,7 +113,7 @@ def _wip_series_from_mapping(
)
def _query_throughput(session: Session, range_spec: RangeSpec) -> DashboardRangeSeries:
async def _query_throughput(session: AsyncSession, range_spec: RangeSpec) -> DashboardRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
statement = (
select(bucket_col, func.count())
@@ -121,12 +123,12 @@ def _query_throughput(session: Session, range_spec: RangeSpec) -> DashboardRange
.group_by(bucket_col)
.order_by(bucket_col)
)
results = session.exec(statement).all()
results = (await session.exec(statement)).all()
mapping = {row[0]: float(row[1]) for row in results}
return _series_from_mapping(range_spec, mapping)
def _query_cycle_time(session: Session, range_spec: RangeSpec) -> DashboardRangeSeries:
async def _query_cycle_time(session: AsyncSession, range_spec: RangeSpec) -> DashboardRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
in_progress = cast(Task.in_progress_at, DateTime)
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
@@ -139,12 +141,12 @@ def _query_cycle_time(session: Session, range_spec: RangeSpec) -> DashboardRange
.group_by(bucket_col)
.order_by(bucket_col)
)
results = session.exec(statement).all()
results = (await session.exec(statement)).all()
mapping = {row[0]: float(row[1] or 0) for row in results}
return _series_from_mapping(range_spec, mapping)
def _query_error_rate(session: Session, range_spec: RangeSpec) -> DashboardRangeSeries:
async def _query_error_rate(session: AsyncSession, range_spec: RangeSpec) -> DashboardRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, ActivityEvent.created_at).label("bucket")
error_case = case(
(
@@ -160,7 +162,7 @@ def _query_error_rate(session: Session, range_spec: RangeSpec) -> DashboardRange
.group_by(bucket_col)
.order_by(bucket_col)
)
results = session.exec(statement).all()
results = (await session.exec(statement)).all()
mapping: dict[datetime, float] = {}
for bucket, errors, total in results:
total_count = float(total or 0)
@@ -170,7 +172,7 @@ def _query_error_rate(session: Session, range_spec: RangeSpec) -> DashboardRange
return _series_from_mapping(range_spec, mapping)
def _query_wip(session: Session, range_spec: RangeSpec) -> DashboardWipRangeSeries:
async def _query_wip(session: AsyncSession, range_spec: RangeSpec) -> DashboardWipRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
inbox_case = case((col(Task.status) == "inbox", 1), else_=0)
progress_case = case((col(Task.status) == "in_progress", 1), else_=0)
@@ -187,7 +189,7 @@ def _query_wip(session: Session, range_spec: RangeSpec) -> DashboardWipRangeSeri
.group_by(bucket_col)
.order_by(bucket_col)
)
results = session.exec(statement).all()
results = (await session.exec(statement)).all()
mapping: dict[datetime, dict[str, int]] = {}
for bucket, inbox, in_progress, review in results:
mapping[bucket] = {
@@ -198,8 +200,8 @@ def _query_wip(session: Session, range_spec: RangeSpec) -> DashboardWipRangeSeri
return _wip_series_from_mapping(range_spec, mapping)
def _median_cycle_time_7d(session: Session) -> float | None:
now = datetime.utcnow()
async def _median_cycle_time_7d(session: AsyncSession) -> float | None:
now = utcnow()
start = now - timedelta(days=7)
in_progress = cast(Task.in_progress_at, DateTime)
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
@@ -210,7 +212,7 @@ def _median_cycle_time_7d(session: Session) -> float | None:
.where(col(Task.updated_at) >= start)
.where(col(Task.updated_at) <= now)
)
value = session.exec(statement).one_or_none()
value = (await session.exec(statement)).one_or_none()
if value is None:
return None
if isinstance(value, tuple):
@@ -220,7 +222,7 @@ def _median_cycle_time_7d(session: Session) -> float | None:
return float(value)
def _error_rate_kpi(session: Session, range_spec: RangeSpec) -> float:
async def _error_rate_kpi(session: AsyncSession, range_spec: RangeSpec) -> float:
error_case = case(
(
col(ActivityEvent.event_type).like(ERROR_EVENT_PATTERN),
@@ -233,7 +235,7 @@ def _error_rate_kpi(session: Session, range_spec: RangeSpec) -> float:
.where(col(ActivityEvent.created_at) >= range_spec.start)
.where(col(ActivityEvent.created_at) <= range_spec.end)
)
result = session.exec(statement).one_or_none()
result = (await session.exec(statement)).one_or_none()
if result is None:
return 0.0
errors, total = result
@@ -242,58 +244,66 @@ def _error_rate_kpi(session: Session, range_spec: RangeSpec) -> float:
return (error_count / total_count) * 100 if total_count > 0 else 0.0
def _active_agents(session: Session) -> int:
threshold = datetime.utcnow() - OFFLINE_AFTER
async def _active_agents(session: AsyncSession) -> int:
threshold = utcnow() - OFFLINE_AFTER
statement = select(func.count()).where(
col(Agent.last_seen_at).is_not(None),
col(Agent.last_seen_at) >= threshold,
)
result = session.exec(statement).one()
result = (await session.exec(statement)).one()
return int(result)
def _tasks_in_progress(session: Session) -> int:
async def _tasks_in_progress(session: AsyncSession) -> int:
statement = select(func.count()).where(col(Task.status) == "in_progress")
result = session.exec(statement).one()
result = (await session.exec(statement)).one()
return int(result)
@router.get("/dashboard", response_model=DashboardMetrics)
def dashboard_metrics(
async def dashboard_metrics(
range: Literal["24h", "7d"] = Query(default="24h"),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> DashboardMetrics:
primary = _resolve_range(range)
comparison = _comparison_range(range)
throughput_primary = await _query_throughput(session, primary)
throughput_comparison = await _query_throughput(session, comparison)
throughput = DashboardSeriesSet(
primary=_query_throughput(session, primary),
comparison=_query_throughput(session, comparison),
primary=throughput_primary,
comparison=throughput_comparison,
)
cycle_time_primary = await _query_cycle_time(session, primary)
cycle_time_comparison = await _query_cycle_time(session, comparison)
cycle_time = DashboardSeriesSet(
primary=_query_cycle_time(session, primary),
comparison=_query_cycle_time(session, comparison),
primary=cycle_time_primary,
comparison=cycle_time_comparison,
)
error_rate_primary = await _query_error_rate(session, primary)
error_rate_comparison = await _query_error_rate(session, comparison)
error_rate = DashboardSeriesSet(
primary=_query_error_rate(session, primary),
comparison=_query_error_rate(session, comparison),
primary=error_rate_primary,
comparison=error_rate_comparison,
)
wip_primary = await _query_wip(session, primary)
wip_comparison = await _query_wip(session, comparison)
wip = DashboardWipSeriesSet(
primary=_query_wip(session, primary),
comparison=_query_wip(session, comparison),
primary=wip_primary,
comparison=wip_comparison,
)
kpis = DashboardKpis(
active_agents=_active_agents(session),
tasks_in_progress=_tasks_in_progress(session),
error_rate_pct=_error_rate_kpi(session, primary),
median_cycle_time_hours_7d=_median_cycle_time_7d(session),
active_agents=await _active_agents(session),
tasks_in_progress=await _tasks_in_progress(session),
error_rate_pct=await _error_rate_kpi(session, primary),
median_cycle_time_hours_7d=await _median_cycle_time_7d(session),
)
return DashboardMetrics(
range=primary.key,
generated_at=datetime.utcnow(),
generated_at=utcnow(),
kpis=kpis,
throughput=throughput,
cycle_time=cycle_time,

View File

@@ -4,14 +4,17 @@ import asyncio
import json
import re
from collections import deque
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import cast
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, delete, desc
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.sql.expression import Select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from app.api.deps import (
ActorContext,
@@ -21,7 +24,8 @@ from app.api.deps import (
require_admin_or_agent,
)
from app.core.auth import AuthContext
from app.db.session import engine, get_session
from app.core.time import utcnow
from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.activity_events import ActivityEvent
@@ -30,6 +34,7 @@ from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.task_fingerprints import TaskFingerprint
from app.models.tasks import Task
from app.schemas.common import OkResponse
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
from app.services.activity_log import record_activity
@@ -46,14 +51,6 @@ SSE_SEEN_MAX = 2000
MENTION_PATTERN = re.compile(r"@([A-Za-z][\w-]{0,31})")
def validate_task_status(status_value: str) -> None:
if status_value not in ALLOWED_STATUSES:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Unsupported task status.",
)
def _comment_validation_error() -> HTTPException:
return HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -61,8 +58,8 @@ def _comment_validation_error() -> HTTPException:
)
def has_valid_recent_comment(
session: Session,
async def has_valid_recent_comment(
session: AsyncSession,
task: Task,
agent_id: UUID | None,
since: datetime | None,
@@ -77,7 +74,7 @@ def has_valid_recent_comment(
.where(col(ActivityEvent.created_at) >= since)
.order_by(desc(col(ActivityEvent.created_at)))
)
event = session.exec(statement).first()
event = (await session.exec(statement)).first()
if event is None or event.message is None:
return False
return bool(event.message.strip())
@@ -116,8 +113,8 @@ def _matches_mention(agent: Agent, mentions: set[str]) -> bool:
return first in mentions
def _lead_was_mentioned(
session: Session,
async def _lead_was_mentioned(
session: AsyncSession,
task: Task,
lead: Agent,
) -> bool:
@@ -127,7 +124,7 @@ def _lead_was_mentioned(
.where(col(ActivityEvent.event_type) == "task.comment")
.order_by(desc(col(ActivityEvent.created_at)))
)
for message in session.exec(statement):
for message in await session.exec(statement):
if not message:
continue
mentions = _extract_mentions(message)
@@ -142,23 +139,24 @@ def _lead_created_task(task: Task, lead: Agent) -> bool:
return task.auto_reason == f"lead_agent:{lead.id}"
def _fetch_task_events(
async def _fetch_task_events(
session: AsyncSession,
board_id: UUID,
since: datetime,
) -> list[tuple[ActivityEvent, Task | None]]:
with Session(engine) as session:
task_ids = list(session.exec(select(Task.id).where(col(Task.board_id) == board_id)))
if not task_ids:
return []
statement = (
select(ActivityEvent, Task)
.outerjoin(Task, ActivityEvent.task_id == Task.id)
.where(col(ActivityEvent.task_id).in_(task_ids))
.where(col(ActivityEvent.event_type).in_(TASK_EVENT_TYPES))
.where(col(ActivityEvent.created_at) >= since)
.order_by(asc(col(ActivityEvent.created_at)))
)
return list(session.exec(statement))
task_ids = list(await session.exec(select(Task.id).where(col(Task.board_id) == board_id)))
if not task_ids:
return []
statement = cast(
Select[tuple[ActivityEvent, Task | None]],
select(ActivityEvent, Task)
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
.where(col(ActivityEvent.task_id).in_(task_ids))
.where(col(ActivityEvent.event_type).in_(TASK_EVENT_TYPES))
.where(col(ActivityEvent.created_at) >= since)
.order_by(asc(col(ActivityEvent.created_at))),
)
return list(await session.exec(statement))
def _serialize_task(task: Task | None) -> dict[str, object] | None:
@@ -171,10 +169,10 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
return TaskCommentRead.model_validate(event).model_dump(mode="json")
def _gateway_config(session: Session, board: Board) -> GatewayClientConfig | None:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if not board.gateway_id:
return None
gateway = session.get(Gateway, board.gateway_id)
gateway = await session.get(Gateway, board.gateway_id)
if gateway is None or not gateway.url:
return None
return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -201,16 +199,16 @@ async def _send_agent_task_message(
await send_message(message, session_key=session_key, config=config, deliver=False)
def _notify_agent_on_task_assign(
async def _notify_agent_on_task_assign(
*,
session: Session,
session: AsyncSession,
board: Board,
task: Task,
agent: Agent,
) -> None:
if not agent.openclaw_session_id:
return
config = _gateway_config(session, board)
config = await _gateway_config(session, board)
if config is None:
return
description = (task.description or "").strip()
@@ -230,13 +228,11 @@ def _notify_agent_on_task_assign(
+ "\n\nTake action: open the task and begin work. Post updates as task comments."
)
try:
asyncio.run(
_send_agent_task_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=message,
)
await _send_agent_task_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=message,
)
record_activity(
session,
@@ -245,7 +241,7 @@ def _notify_agent_on_task_assign(
agent_id=agent.id,
task_id=task.id,
)
session.commit()
await session.commit()
except OpenClawGatewayError as exc:
record_activity(
session,
@@ -254,21 +250,25 @@ def _notify_agent_on_task_assign(
agent_id=agent.id,
task_id=task.id,
)
session.commit()
await session.commit()
def _notify_lead_on_task_create(
async def _notify_lead_on_task_create(
*,
session: Session,
session: AsyncSession,
board: Board,
task: Task,
) -> None:
lead = session.exec(
select(Agent).where(Agent.board_id == board.id).where(Agent.is_board_lead.is_(True))
lead = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
)
).first()
if lead is None or not lead.openclaw_session_id:
return
config = _gateway_config(session, board)
config = await _gateway_config(session, board)
if config is None:
return
description = (task.description or "").strip()
@@ -288,12 +288,10 @@ def _notify_lead_on_task_create(
+ "\n\nTake action: triage, assign, or plan next steps."
)
try:
asyncio.run(
_send_lead_task_message(
session_key=lead.openclaw_session_id,
config=config,
message=message,
)
await _send_lead_task_message(
session_key=lead.openclaw_session_id,
config=config,
message=message,
)
record_activity(
session,
@@ -302,7 +300,7 @@ def _notify_lead_on_task_create(
agent_id=lead.id,
task_id=task.id,
)
session.commit()
await session.commit()
except OpenClawGatewayError as exc:
record_activity(
session,
@@ -311,21 +309,25 @@ def _notify_lead_on_task_create(
agent_id=lead.id,
task_id=task.id,
)
session.commit()
await session.commit()
def _notify_lead_on_task_unassigned(
async def _notify_lead_on_task_unassigned(
*,
session: Session,
session: AsyncSession,
board: Board,
task: Task,
) -> None:
lead = session.exec(
select(Agent).where(Agent.board_id == board.id).where(Agent.is_board_lead.is_(True))
lead = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
)
).first()
if lead is None or not lead.openclaw_session_id:
return
config = _gateway_config(session, board)
config = await _gateway_config(session, board)
if config is None:
return
description = (task.description or "").strip()
@@ -345,12 +347,10 @@ def _notify_lead_on_task_unassigned(
+ "\n\nTake action: assign a new owner or adjust the plan."
)
try:
asyncio.run(
_send_lead_task_message(
session_key=lead.openclaw_session_id,
config=config,
message=message,
)
await _send_lead_task_message(
session_key=lead.openclaw_session_id,
config=config,
message=message,
)
record_activity(
session,
@@ -359,7 +359,7 @@ def _notify_lead_on_task_unassigned(
agent_id=lead.id,
task_id=task.id,
)
session.commit()
await session.commit()
except OpenClawGatewayError as exc:
record_activity(
session,
@@ -368,7 +368,7 @@ def _notify_lead_on_task_unassigned(
agent_id=lead.id,
task_id=task.id,
)
session.commit()
await session.commit()
@router.get("/stream")
@@ -378,16 +378,17 @@ async def stream_tasks(
actor: ActorContext = Depends(require_admin_or_agent),
since: str | None = Query(default=None),
) -> EventSourceResponse:
since_dt = _parse_since(since) or datetime.utcnow()
since_dt = _parse_since(since) or utcnow()
seen_ids: set[UUID] = set()
seen_queue: deque[UUID] = deque()
async def event_generator():
async def event_generator() -> AsyncIterator[dict[str, str]]:
last_seen = since_dt
while True:
if await request.is_disconnected():
break
rows = await run_in_threadpool(_fetch_task_events, board.id, last_seen)
async with async_session_maker() as session:
rows = await _fetch_task_events(session, board.id, last_seen)
for event, task in rows:
if event.id in seen_ids:
continue
@@ -410,13 +411,13 @@ async def stream_tasks(
@router.get("", response_model=list[TaskRead])
def list_tasks(
async def list_tasks(
status_filter: str | None = Query(default=None, alias="status"),
assigned_agent_id: UUID | None = None,
unassigned: bool | None = None,
limit: int | None = Query(default=None, ge=1, le=200),
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> list[Task]:
statement = select(Task).where(Task.board_id == board.id)
@@ -435,24 +436,23 @@ def list_tasks(
statement = statement.where(col(Task.assigned_agent_id).is_(None))
if limit is not None:
statement = statement.limit(limit)
return list(session.exec(statement))
return list(await session.exec(statement))
@router.post("", response_model=TaskRead)
def create_task(
async def create_task(
payload: TaskCreate,
board: Board = Depends(get_board_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(require_admin_auth),
) -> Task:
validate_task_status(payload.status)
task = Task.model_validate(payload)
task.board_id = board.id
if task.created_by_user_id is None and auth.user is not None:
task.created_by_user_id = auth.user.id
session.add(task)
session.commit()
session.refresh(task)
await session.commit()
await session.refresh(task)
record_activity(
session,
@@ -460,12 +460,12 @@ def create_task(
task_id=task.id,
message=f"Task created: {task.title}.",
)
session.commit()
_notify_lead_on_task_create(session=session, board=board, task=task)
await session.commit()
await _notify_lead_on_task_create(session=session, board=board, task=task)
if task.assigned_agent_id:
assigned_agent = session.get(Agent, task.assigned_agent_id)
assigned_agent = await session.get(Agent, task.assigned_agent_id)
if assigned_agent:
_notify_agent_on_task_assign(
await _notify_agent_on_task_assign(
session=session,
board=board,
task=task,
@@ -475,18 +475,16 @@ def create_task(
@router.patch("/{task_id}", response_model=TaskRead)
def update_task(
async def update_task(
payload: TaskUpdate,
task: Task = Depends(get_task_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> Task:
previous_status = task.status
previous_assigned = task.assigned_agent_id
updates = payload.model_dump(exclude_unset=True)
comment = updates.pop("comment", None)
if comment is not None and not comment.strip():
comment = None
if actor.actor_type == "agent" and actor.agent and actor.agent.is_board_lead:
allowed_fields = {"assigned_agent_id", "status"}
@@ -498,7 +496,7 @@ def update_task(
if "assigned_agent_id" in updates:
assigned_id = updates["assigned_agent_id"]
if assigned_id:
agent = session.get(Agent, assigned_id)
agent = await session.get(Agent, assigned_id)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.is_board_lead:
@@ -512,7 +510,6 @@ def update_task(
else:
task.assigned_agent_id = None
if "status" in updates:
validate_task_status(updates["status"])
if task.status != "review":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -526,8 +523,8 @@ def update_task(
if updates["status"] == "inbox":
task.assigned_agent_id = None
task.in_progress_at = None
task.status = updates["status"]
task.updated_at = datetime.utcnow()
task.status = updates["status"]
task.updated_at = utcnow()
session.add(task)
if task.status != previous_status:
event_type = "task.status_changed"
@@ -542,17 +539,17 @@ def update_task(
message=message,
agent_id=actor.agent.id,
)
session.commit()
session.refresh(task)
await session.commit()
await session.refresh(task)
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
if actor.actor_type == "agent" and actor.agent and task.assigned_agent_id == actor.agent.id:
return task
assigned_agent = session.get(Agent, task.assigned_agent_id)
assigned_agent = await session.get(Agent, task.assigned_agent_id)
if assigned_agent:
board = session.get(Board, task.board_id) if task.board_id else None
board = await session.get(Board, task.board_id) if task.board_id else None
if board:
_notify_agent_on_task_assign(
await _notify_agent_on_task_assign(
session=session,
board=board,
task=task,
@@ -567,37 +564,35 @@ def update_task(
if not set(updates).issubset(allowed_fields):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if "status" in updates:
validate_task_status(updates["status"])
if updates["status"] == "inbox":
task.assigned_agent_id = None
task.in_progress_at = None
else:
task.assigned_agent_id = actor.agent.id if actor.agent else None
if updates["status"] == "in_progress":
task.in_progress_at = datetime.utcnow()
task.in_progress_at = utcnow()
elif "status" in updates:
validate_task_status(updates["status"])
if updates["status"] == "inbox":
task.assigned_agent_id = None
task.in_progress_at = None
elif updates["status"] == "in_progress":
task.in_progress_at = datetime.utcnow()
task.in_progress_at = utcnow()
if "assigned_agent_id" in updates and updates["assigned_agent_id"]:
agent = session.get(Agent, updates["assigned_agent_id"])
agent = await session.get(Agent, updates["assigned_agent_id"])
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.board_id and task.board_id and agent.board_id != task.board_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
for key, value in updates.items():
setattr(task, key, value)
task.updated_at = datetime.utcnow()
task.updated_at = utcnow()
if "status" in updates and updates["status"] == "review":
if comment is not None and comment.strip():
if not comment.strip():
raise _comment_validation_error()
else:
if not has_valid_recent_comment(
if not await has_valid_recent_comment(
session,
task,
task.assigned_agent_id,
@@ -606,8 +601,8 @@ def update_task(
raise _comment_validation_error()
session.add(task)
session.commit()
session.refresh(task)
await session.commit()
await session.refresh(task)
if comment is not None and comment.strip():
event = ActivityEvent(
@@ -617,7 +612,7 @@ def update_task(
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None,
)
session.add(event)
session.commit()
await session.commit()
if "status" in updates and task.status != previous_status:
event_type = "task.status_changed"
@@ -632,12 +627,12 @@ def update_task(
message=message,
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None,
)
session.commit()
await session.commit()
if task.status == "inbox" and task.assigned_agent_id is None:
if previous_status != "inbox" or previous_assigned is not None:
board = session.get(Board, task.board_id) if task.board_id else None
board = await session.get(Board, task.board_id) if task.board_id else None
if board:
_notify_lead_on_task_unassigned(
await _notify_lead_on_task_unassigned(
session=session,
board=board,
task=task,
@@ -645,11 +640,11 @@ def update_task(
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
if actor.actor_type == "agent" and actor.agent and task.assigned_agent_id == actor.agent.id:
return task
assigned_agent = session.get(Agent, task.assigned_agent_id)
assigned_agent = await session.get(Agent, task.assigned_agent_id)
if assigned_agent:
board = session.get(Board, task.board_id) if task.board_id else None
board = await session.get(Board, task.board_id) if task.board_id else None
if board:
_notify_agent_on_task_assign(
await _notify_agent_on_task_assign(
session=session,
board=board,
task=task,
@@ -658,23 +653,23 @@ def update_task(
return task
@router.delete("/{task_id}")
def delete_task(
session: Session = Depends(get_session),
@router.delete("/{task_id}", response_model=OkResponse)
async def delete_task(
session: AsyncSession = Depends(get_session),
task: Task = Depends(get_task_or_404),
auth: AuthContext = Depends(require_admin_auth),
) -> dict[str, bool]:
session.execute(delete(ActivityEvent).where(col(ActivityEvent.task_id) == task.id))
session.execute(delete(TaskFingerprint).where(col(TaskFingerprint.task_id) == task.id))
session.delete(task)
session.commit()
return {"ok": True}
) -> OkResponse:
await session.execute(delete(ActivityEvent).where(col(ActivityEvent.task_id) == task.id))
await session.execute(delete(TaskFingerprint).where(col(TaskFingerprint.task_id) == task.id))
await session.delete(task)
await session.commit()
return OkResponse()
@router.get("/{task_id}/comments", response_model=list[TaskCommentRead])
def list_task_comments(
async def list_task_comments(
task: Task = Depends(get_task_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> list[ActivityEvent]:
if actor.actor_type == "agent" and actor.agent:
@@ -686,19 +681,19 @@ def list_task_comments(
.where(col(ActivityEvent.event_type) == "task.comment")
.order_by(asc(col(ActivityEvent.created_at)))
)
return list(session.exec(statement))
return list(await session.exec(statement))
@router.post("/{task_id}/comments", response_model=TaskCommentRead)
def create_task_comment(
async def create_task_comment(
payload: TaskCommentCreate,
task: Task = Depends(get_task_or_404),
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> ActivityEvent:
if actor.actor_type == "agent" and actor.agent:
if actor.agent.is_board_lead and task.status != "review":
if not _lead_was_mentioned(session, task, actor.agent) and not _lead_created_task(
if not await _lead_was_mentioned(session, task, actor.agent) and not _lead_created_task(
task, actor.agent
):
raise HTTPException(
@@ -709,8 +704,6 @@ def create_task_comment(
)
if actor.agent.board_id and task.board_id and actor.agent.board_id != task.board_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not payload.message.strip():
raise _comment_validation_error()
event = ActivityEvent(
event_type="task.comment",
message=payload.message,
@@ -718,24 +711,24 @@ def create_task_comment(
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None,
)
session.add(event)
session.commit()
session.refresh(event)
await session.commit()
await session.refresh(event)
mention_names = _extract_mentions(payload.message)
targets: dict[UUID, Agent] = {}
if mention_names and task.board_id:
statement = select(Agent).where(col(Agent.board_id) == task.board_id)
for agent in session.exec(statement):
for agent in await session.exec(statement):
if _matches_mention(agent, mention_names):
targets[agent.id] = agent
if not mention_names and task.assigned_agent_id:
assigned_agent = session.get(Agent, task.assigned_agent_id)
assigned_agent = await session.get(Agent, task.assigned_agent_id)
if assigned_agent:
targets[assigned_agent.id] = assigned_agent
if actor.actor_type == "agent" and actor.agent:
targets.pop(actor.agent.id, None)
if targets:
board = session.get(Board, task.board_id) if task.board_id else None
config = _gateway_config(session, board) if board else None
board = await session.get(Board, task.board_id) if task.board_id else None
config = await _gateway_config(session, board) if board else None
if board and config:
snippet = payload.message.strip()
if len(snippet) > 500:
@@ -762,13 +755,11 @@ def create_task_comment(
"If you are mentioned but not assigned, reply in the task thread but do not change task status."
)
try:
asyncio.run(
_send_agent_task_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=message,
)
await _send_agent_task_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=message,
)
except OpenClawGatewayError:
pass

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext, get_auth_context
from app.db.session import get_session
@@ -21,7 +21,7 @@ async def get_me(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
@router.patch("/me", response_model=UserRead)
async def update_me(
payload: UserUpdate,
session: Session = Depends(get_session),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> UserRead:
if auth.actor_type != "user" or auth.user is None:
@@ -31,6 +31,6 @@ async def update_me(
for key, value in updates.items():
setattr(user, key, value)
session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return UserRead.model_validate(user)