diff --git a/backend/app/api/activity.py b/backend/app/api/activity.py index 8b8f0993..ca59fc40 100644 --- a/backend/app/api/activity.py +++ b/backend/app/api/activity.py @@ -24,7 +24,11 @@ from app.models.boards import Board from app.models.tasks import Task from app.schemas.activity_events import ActivityEventRead, ActivityTaskCommentFeedItemRead from app.schemas.pagination import DefaultLimitOffsetPage -from app.services.organizations import get_active_membership, list_accessible_board_ids +from app.services.organizations import ( + OrganizationContext, + get_active_membership, + list_accessible_board_ids, +) router = APIRouter(prefix="/activity", tags=["activity"]) @@ -134,7 +138,7 @@ async def list_activity( async def list_task_comment_feed( board_id: UUID | None = Query(default=None), session: AsyncSession = Depends(get_session), - ctx=Depends(require_org_member), + ctx: OrganizationContext = Depends(require_org_member), ) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]: statement = ( select(ActivityEvent, Task, Board, Agent) @@ -168,7 +172,7 @@ async def stream_task_comment_feed( board_id: UUID | None = Query(default=None), since: str | None = Query(default=None), session: AsyncSession = Depends(get_session), - ctx=Depends(require_org_member), + ctx: OrganizationContext = Depends(require_org_member), ) -> EventSourceResponse: since_dt = _parse_since(since) or utcnow() board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index cf06d1da..45e9d0ea 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -10,6 +10,7 @@ from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from sqlalchemy import asc, or_, update +from sqlalchemy.sql.elements import ColumnElement from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession from sse_starlette.sse import EventSourceResponse @@ -245,7 +246,7 @@ async def _require_agent_access( session: AsyncSession, *, agent: Agent, - ctx, + ctx: OrganizationContext, write: bool, ) -> None: if agent.board_id is None: @@ -302,7 +303,7 @@ async def list_agents( board_id: UUID | None = Query(default=None), gateway_id: UUID | None = Query(default=None), session: AsyncSession = Depends(get_session), - ctx=Depends(require_org_admin), + ctx: OrganizationContext = Depends(require_org_admin), ) -> DefaultLimitOffsetPage[AgentRead]: main_session_keys = await _get_gateway_main_session_keys(session) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) @@ -311,7 +312,7 @@ async def list_agents( if not board_ids: statement = select(Agent).where(col(Agent.id).is_(None)) else: - base_filter = col(Agent.board_id).in_(board_ids) + base_filter: ColumnElement[bool] = col(Agent.board_id).in_(board_ids) if is_org_admin(ctx.member): gateway_keys = select(Gateway.main_session_key).where( col(Gateway.organization_id) == ctx.organization.id @@ -342,7 +343,7 @@ async def stream_agents( board_id: UUID | None = Query(default=None), since: str | None = Query(default=None), session: AsyncSession = Depends(get_session), - ctx=Depends(require_org_admin), + ctx: OrganizationContext = Depends(require_org_admin), ) -> EventSourceResponse: since_dt = _parse_since(since) or utcnow() last_seen = since_dt @@ -528,7 +529,7 @@ async def create_agent( async def get_agent( agent_id: str, session: AsyncSession = Depends(get_session), - ctx=Depends(require_org_admin), + ctx: OrganizationContext = Depends(require_org_admin), ) -> AgentRead: agent = await session.get(Agent, agent_id) if agent is None: @@ -545,7 +546,7 @@ async def update_agent( force: bool = False, session: AsyncSession = Depends(get_session), auth: AuthContext = Depends(get_auth_context), - ctx=Depends(require_org_admin), + ctx: OrganizationContext = Depends(require_org_admin), ) -> AgentRead: agent = await session.get(Agent, agent_id) if agent is None: @@ -841,57 +842,62 @@ async def heartbeat_or_create_agent( except Exception as exc: # pragma: no cover - unexpected provisioning errors _record_instruction_failure(session, agent, str(exc), "provision") await session.commit() - elif actor.actor_type == "user": - ctx = await _require_user_context(session, actor.user) - await _require_agent_access(session, agent=agent, ctx=ctx, write=True) - elif actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - elif agent.agent_token_hash is None and actor.actor_type == "user": - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - if agent.heartbeat_config is None: - agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() - agent.provision_requested_at = utcnow() - agent.provision_action = "provision" - session.add(agent) - await session.commit() - await session.refresh(agent) - try: - board = await _require_board( - session, - str(agent.board_id) if agent.board_id else None, - user=actor.user if actor.actor_type == "user" else None, - write=actor.actor_type == "user", - ) - gateway, client_config = await _require_gateway(session, board) - await provision_agent(agent, board, gateway, raw_token, actor.user, action="provision") - await _send_wakeup_message(agent, client_config, verb="provisioned") - agent.provision_confirm_token_hash = None - agent.provision_requested_at = None - agent.provision_action = None - agent.updated_at = utcnow() - session.add(agent) - await session.commit() - record_activity( - session, - event_type="agent.provision", - message=f"Provisioned directly for {agent.name}.", - agent_id=agent.id, - ) - record_activity( - session, - event_type="agent.wakeup.sent", - message=f"Wakeup message sent to {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - except OpenClawGatewayError as exc: - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() - except Exception as exc: # pragma: no cover - unexpected provisioning errors - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() - elif not agent.openclaw_session_id: + else: + if actor.actor_type == "user": + ctx = await _require_user_context(session, actor.user) + await _require_agent_access(session, agent=agent, ctx=ctx, write=True) + + if agent.agent_token_hash is None: + raw_token = generate_agent_token() + agent.agent_token_hash = hash_agent_token(raw_token) + if agent.heartbeat_config is None: + agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() + agent.provision_requested_at = utcnow() + agent.provision_action = "provision" + session.add(agent) + await session.commit() + await session.refresh(agent) + try: + board = await _require_board( + session, + str(agent.board_id) if agent.board_id else None, + user=actor.user, + write=True, + ) + gateway, client_config = await _require_gateway(session, board) + await provision_agent( + agent, board, gateway, raw_token, actor.user, action="provision" + ) + await _send_wakeup_message(agent, client_config, verb="provisioned") + agent.provision_confirm_token_hash = None + agent.provision_requested_at = None + agent.provision_action = None + agent.updated_at = utcnow() + session.add(agent) + await session.commit() + record_activity( + session, + event_type="agent.provision", + message=f"Provisioned directly for {agent.name}.", + agent_id=agent.id, + ) + record_activity( + session, + event_type="agent.wakeup.sent", + message=f"Wakeup message sent to {agent.name}.", + agent_id=agent.id, + ) + await session.commit() + except OpenClawGatewayError as exc: + _record_instruction_failure(session, agent, str(exc), "provision") + await session.commit() + except Exception as exc: # pragma: no cover - unexpected provisioning errors + _record_instruction_failure(session, agent, str(exc), "provision") + await session.commit() + elif actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + if not agent.openclaw_session_id: board = await _require_board( session, str(agent.board_id) if agent.board_id else None, @@ -934,7 +940,7 @@ async def heartbeat_or_create_agent( async def delete_agent( agent_id: str, session: AsyncSession = Depends(get_session), - ctx=Depends(require_org_admin), + ctx: OrganizationContext = Depends(require_org_admin), ) -> OkResponse: agent = await session.get(Agent, agent_id) if agent is None: