diff --git a/backend/app/api/agent.py b/backend/app/api/agent.py index 11a2b0ed..5c947fb6 100644 --- a/backend/app/api/agent.py +++ b/backend/app/api/agent.py @@ -63,7 +63,11 @@ from app.schemas.tasks import ( TaskUpdate, ) from app.services.activity_log import record_activity -from app.services.board_leads import ensure_board_lead_agent +from app.services.board_leads import ( + LeadAgentOptions, + LeadAgentRequest, + ensure_board_lead_agent, +) from app.services.task_dependencies import ( blocked_by_dependency_ids, dependency_status_by_id, @@ -113,6 +117,29 @@ class SoulUpdateRequest(SQLModel): reason: str | None = None +class AgentTaskListFilters(SQLModel): + """Query filters for board task listing in agent routes.""" + + status_filter: str | None = None + assigned_agent_id: UUID | None = None + unassigned: bool | None = None + + +def _task_list_filters( + status_filter: str | None = TASK_STATUS_QUERY, + assigned_agent_id: UUID | None = None, + unassigned: bool | None = None, +) -> AgentTaskListFilters: + return AgentTaskListFilters( + status_filter=status_filter, + assigned_agent_id=assigned_agent_id, + unassigned=unassigned, + ) + + +TASK_LIST_FILTERS_DEP = Depends(_task_list_filters) + + def _actor(agent_ctx: AgentAuthContext) -> ActorContext: return ActorContext(actor_type="agent", agent=agent_ctx.agent) @@ -217,19 +244,16 @@ async def list_agents( statement = statement.where(Agent.board_id == agent_ctx.agent.board_id) elif board_id: statement = statement.where(Agent.board_id == board_id) - get_gateway_main_session_keys = ( - agents_api._get_gateway_main_session_keys # noqa: SLF001 - ) - to_agent_read = agents_api._to_agent_read # noqa: SLF001 - with_computed_status = agents_api._with_computed_status # noqa: SLF001 - - main_session_keys = await get_gateway_main_session_keys(session) + main_session_keys = await agents_api.get_gateway_main_session_keys(session) statement = statement.order_by(col(Agent.created_at).desc()) def _transform(items: Sequence[Any]) -> Sequence[Any]: agents = cast(Sequence[Agent], items) return [ - to_agent_read(with_computed_status(agent), main_session_keys) + agents_api.to_agent_read( + agents_api.with_computed_status(agent), + main_session_keys, + ) for agent in agents ] @@ -237,10 +261,8 @@ async def list_agents( @router.get("/boards/{board_id}/tasks", response_model=DefaultLimitOffsetPage[TaskRead]) -async def list_tasks( # noqa: PLR0913 - status_filter: str | None = TASK_STATUS_QUERY, - assigned_agent_id: UUID | None = None, - unassigned: bool | None = None, +async def list_tasks( + filters: AgentTaskListFilters = TASK_LIST_FILTERS_DEP, board: Board = BOARD_DEP, session: AsyncSession = SESSION_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP, @@ -248,9 +270,9 @@ async def list_tasks( # noqa: PLR0913 """List tasks on a board with optional status and assignment filters.""" _guard_board_access(agent_ctx, board) return await tasks_api.list_tasks( - status_filter=status_filter, - assigned_agent_id=assigned_agent_id, - unassigned=unassigned, + status_filter=filters.status_filter, + assigned_agent_id=filters.assigned_agent_id, + unassigned=filters.unassigned, board=board, session=session, actor=_actor(agent_ctx), @@ -336,10 +358,7 @@ async def create_task( session, ) if assigned_agent: - notify_agent_on_task_assign = ( - tasks_api._notify_agent_on_task_assign # noqa: SLF001 - ) - await notify_agent_on_task_assign( + await tasks_api.notify_agent_on_task_assign( session=session, board=board, task=task, @@ -821,11 +840,13 @@ async def message_gateway_board_lead( board = await _require_gateway_board(session, gateway=gateway, board_id=board_id) lead, lead_created = await ensure_board_lead_agent( session, - board=board, - gateway=gateway, - config=config, - user=None, - action="provision", + request=LeadAgentRequest( + board=board, + gateway=gateway, + config=config, + user=None, + options=LeadAgentOptions(action="provision"), + ), ) if not lead.openclaw_session_id: raise HTTPException( @@ -932,11 +953,13 @@ async def broadcast_gateway_lead_message( try: lead, _lead_created = await ensure_board_lead_agent( session, - board=board, - gateway=gateway, - config=config, - user=None, - action="provision", + request=LeadAgentRequest( + board=board, + gateway=gateway, + config=config, + user=None, + options=LeadAgentOptions(action="provision"), + ), ) lead_session_key = _require_lead_session_key(lead) message = ( diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index ff2e21c6..64560e5c 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -6,6 +6,7 @@ import asyncio import json import re from collections.abc import AsyncIterator, Sequence +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, cast from uuid import UUID, uuid4 @@ -46,6 +47,9 @@ from app.schemas.pagination import DefaultLimitOffsetPage from app.services.activity_log import record_activity from app.services.agent_provisioning import ( DEFAULT_HEARTBEAT_CONFIG, + AgentProvisionRequest, + MainAgentProvisionRequest, + ProvisionOptions, cleanup_agent, provision_agent, provision_main_agent, @@ -78,6 +82,25 @@ ACTOR_DEP = Depends(require_admin_or_agent) AUTH_DEP = Depends(get_auth_context) +@dataclass(frozen=True, slots=True) +class _AgentUpdateParams: + force: bool + auth: AuthContext + ctx: OrganizationContext + + +def _agent_update_params( + *, + force: bool = False, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> _AgentUpdateParams: + return _AgentUpdateParams(force=force, auth=auth, ctx=ctx) + + +AGENT_UPDATE_PARAMS_DEP = Depends(_agent_update_params) + + def _parse_since(value: str | None) -> datetime | None: if not value: return None @@ -199,6 +222,16 @@ def _to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead: ) +async def get_gateway_main_session_keys(session: AsyncSession) -> set[str]: + """Return gateway main-session keys used to compute `is_gateway_main`.""" + return await _get_gateway_main_session_keys(session) + + +def to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead: + """Convert an `Agent` model into its API read representation.""" + return _to_agent_read(agent, main_session_keys) + + async def _find_gateway_for_main_session( session: AsyncSession, session_key: str | None, ) -> Gateway | None: @@ -231,6 +264,11 @@ def _with_computed_status(agent: Agent) -> Agent: return agent +def with_computed_status(agent: Agent) -> Agent: + """Apply transient online/offline status derivation to an agent model.""" + return _with_computed_status(agent) + + def _serialize_agent(agent: Agent, main_session_keys: set[str]) -> dict[str, object]: return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump( mode="json", @@ -315,6 +353,577 @@ def _record_instruction_failure( ) +async def _coerce_agent_create_payload( + session: AsyncSession, + payload: AgentCreate, + actor: ActorContext, +) -> AgentCreate: + if actor.actor_type == "user": + ctx = await _require_user_context(session, actor.user) + if not is_org_admin(ctx.member): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + return payload + + if actor.actor_type == "agent": + if not actor.agent or not actor.agent.is_board_lead: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only board leads can create agents", + ) + if not actor.agent.board_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Board lead must be assigned to a board", + ) + if payload.board_id and payload.board_id != actor.agent.board_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Board leads can only create agents in their own board", + ) + return AgentCreate(**{**payload.model_dump(), "board_id": actor.agent.board_id}) + + return payload + + +async def _ensure_unique_agent_name( + session: AsyncSession, + *, + board: Board, + gateway: Gateway, + requested_name: str, +) -> None: + if not requested_name: + return + + existing = ( + await session.exec( + select(Agent) + .where(Agent.board_id == board.id) + .where(col(Agent.name).ilike(requested_name)), + ) + ).first() + if existing: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="An agent with this name already exists on this board.", + ) + + existing_gateway = ( + await session.exec( + select(Agent) + .join(Board, col(Agent.board_id) == col(Board.id)) + .where(col(Board.gateway_id) == gateway.id) + .where(col(Agent.name).ilike(requested_name)), + ) + ).first() + if existing_gateway: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=( + "An agent with this name already exists in this gateway " + "workspace." + ), + ) + + desired_session_key = _build_session_key(requested_name) + existing_session_key = ( + await session.exec( + select(Agent) + .join(Board, col(Agent.board_id) == col(Board.id)) + .where(col(Board.gateway_id) == gateway.id) + .where(col(Agent.openclaw_session_id) == desired_session_key), + ) + ).first() + if existing_session_key: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=( + "This agent name would collide with an existing workspace " + "session key. Pick a different name." + ), + ) + + +async def _persist_new_agent( + session: AsyncSession, + *, + data: dict[str, Any], + client_config: GatewayClientConfig, +) -> tuple[Agent, str, str | None]: + agent = Agent.model_validate(data) + agent.status = "provisioning" + raw_token = generate_agent_token() + agent.agent_token_hash = hash_agent_token(raw_token) + if agent.heartbeat_config is None: + agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() + agent.provision_requested_at = utcnow() + agent.provision_action = "provision" + session_key, session_error = await _ensure_gateway_session( + agent.name, + client_config, + ) + agent.openclaw_session_id = session_key + session.add(agent) + await session.commit() + await session.refresh(agent) + return agent, raw_token, session_error + + +async def _record_session_creation( + session: AsyncSession, + *, + agent: Agent, + session_error: str | None, +) -> None: + if session_error: + record_activity( + session, + event_type="agent.session.failed", + message=f"Session sync failed for {agent.name}: {session_error}", + agent_id=agent.id, + ) + else: + record_activity( + session, + event_type="agent.session.created", + message=f"Session created for {agent.name}.", + agent_id=agent.id, + ) + await session.commit() + + +async def _provision_new_agent( + session: AsyncSession, + *, + agent: Agent, + request: AgentProvisionRequest, + client_config: GatewayClientConfig, +) -> None: + try: + await provision_agent(agent, request) + await _send_wakeup_message(agent, client_config, verb="provisioned") + agent.provision_confirm_token_hash = None + agent.provision_requested_at = None + agent.provision_action = None + agent.updated_at = utcnow() + session.add(agent) + await session.commit() + record_activity( + session, + event_type="agent.provision", + message=f"Provisioned directly for {agent.name}.", + agent_id=agent.id, + ) + record_activity( + session, + event_type="agent.wakeup.sent", + message=f"Wakeup message sent to {agent.name}.", + agent_id=agent.id, + ) + await session.commit() + except OpenClawGatewayError as exc: + _record_instruction_failure(session, agent, str(exc), "provision") + await session.commit() + except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover + _record_instruction_failure(session, agent, str(exc), "provision") + await session.commit() + + +@dataclass(frozen=True, slots=True) +class _AgentUpdateProvisionTarget: + is_main_agent: bool + board: Board | None + gateway: Gateway + client_config: GatewayClientConfig + + +@dataclass(frozen=True, slots=True) +class _AgentUpdateProvisionRequest: + target: _AgentUpdateProvisionTarget + raw_token: str + user: User | None + force_bootstrap: bool + + +async def _validate_agent_update_inputs( + session: AsyncSession, + *, + ctx: OrganizationContext, + updates: dict[str, Any], + make_main: bool | None, +) -> None: + if make_main is True and not is_org_admin(ctx.member): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + if "status" in updates: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="status is controlled by agent heartbeat", + ) + if "board_id" in updates and updates["board_id"] is not None: + new_board = await _require_board(session, updates["board_id"]) + if new_board.organization_id != ctx.organization.id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + if not await has_board_access( + session, + member=ctx.member, + board=new_board, + write=True, + ): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + +async def _apply_agent_update_mutations( + session: AsyncSession, + *, + agent: Agent, + updates: dict[str, Any], + make_main: bool | None, +) -> tuple[Gateway | None, Gateway | None]: + main_gateway = await _find_gateway_for_main_session( + session, + agent.openclaw_session_id, + ) + gateway_for_main: Gateway | None = None + + if make_main is True: + board_source = updates.get("board_id") or agent.board_id + board_for_main = await _require_board(session, board_source) + gateway_for_main, _ = await _require_gateway(session, board_for_main) + updates["board_id"] = None + agent.is_board_lead = False + agent.openclaw_session_id = gateway_for_main.main_session_key + main_gateway = gateway_for_main + elif make_main is False: + agent.openclaw_session_id = None + + if make_main is not True and "board_id" in updates: + await _require_board(session, updates["board_id"]) + for key, value in updates.items(): + setattr(agent, key, value) + + if make_main is None and main_gateway is not None: + agent.board_id = None + agent.is_board_lead = False + agent.updated_at = utcnow() + if agent.heartbeat_config is None: + agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() + session.add(agent) + await session.commit() + await session.refresh(agent) + return main_gateway, gateway_for_main + + +async def _resolve_agent_update_target( + session: AsyncSession, + *, + agent: Agent, + make_main: bool | None, + main_gateway: Gateway | None, + gateway_for_main: Gateway | None, +) -> _AgentUpdateProvisionTarget: + if make_main is True: + if gateway_for_main is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Main agent requires a gateway main_session_key", + ) + if not gateway_for_main.main_session_key: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Gateway main_session_key is required", + ) + return _AgentUpdateProvisionTarget( + is_main_agent=True, + board=None, + gateway=gateway_for_main, + client_config=_gateway_client_config(gateway_for_main), + ) + + if make_main is None and agent.board_id is None and main_gateway is not None: + if not main_gateway.main_session_key: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Gateway main_session_key is required", + ) + return _AgentUpdateProvisionTarget( + is_main_agent=True, + board=None, + gateway=main_gateway, + client_config=_gateway_client_config(main_gateway), + ) + + if agent.board_id is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="board_id is required for non-main agents", + ) + board = await _require_board(session, agent.board_id) + gateway, client_config = await _require_gateway(session, board) + return _AgentUpdateProvisionTarget( + is_main_agent=False, + board=board, + gateway=gateway, + client_config=client_config, + ) + + +async def _ensure_agent_update_session( + session: AsyncSession, + *, + agent: Agent, + client_config: GatewayClientConfig, +) -> None: + session_key = agent.openclaw_session_id or _build_session_key(agent.name) + try: + await ensure_session(session_key, config=client_config, label=agent.name) + if not agent.openclaw_session_id: + agent.openclaw_session_id = session_key + session.add(agent) + await session.commit() + await session.refresh(agent) + except OpenClawGatewayError as exc: + _record_instruction_failure(session, agent, str(exc), "update") + await session.commit() + + +def _mark_agent_update_pending(agent: Agent) -> str: + raw_token = generate_agent_token() + agent.agent_token_hash = hash_agent_token(raw_token) + agent.provision_requested_at = utcnow() + agent.provision_action = "update" + agent.status = "updating" + return raw_token + + +async def _provision_updated_agent( + session: AsyncSession, + *, + agent: Agent, + request: _AgentUpdateProvisionRequest, +) -> None: + try: + if request.target.is_main_agent: + await provision_main_agent( + agent, + MainAgentProvisionRequest( + gateway=request.target.gateway, + auth_token=request.raw_token, + user=request.user, + options=ProvisionOptions( + action="update", + force_bootstrap=request.force_bootstrap, + reset_session=True, + ), + ), + ) + else: + if request.target.board is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="board is required for non-main agent provisioning", + ) + await provision_agent( + agent, + AgentProvisionRequest( + board=request.target.board, + gateway=request.target.gateway, + auth_token=request.raw_token, + user=request.user, + options=ProvisionOptions( + action="update", + force_bootstrap=request.force_bootstrap, + reset_session=True, + ), + ), + ) + await _send_wakeup_message( + agent, + request.target.client_config, + verb="updated", + ) + agent.provision_confirm_token_hash = None + agent.provision_requested_at = None + agent.provision_action = None + agent.status = "online" + agent.updated_at = utcnow() + session.add(agent) + await session.commit() + record_activity( + session, + event_type="agent.update.direct", + message=f"Updated directly for {agent.name}.", + agent_id=agent.id, + ) + record_activity( + session, + event_type="agent.wakeup.sent", + message=f"Wakeup message sent to {agent.name}.", + agent_id=agent.id, + ) + await session.commit() + except OpenClawGatewayError as exc: + _record_instruction_failure(session, agent, str(exc), "update") + await session.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Gateway update failed: {exc}", + ) from exc + except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover + _record_instruction_failure(session, agent, str(exc), "update") + await session.commit() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Unexpected error updating agent provisioning.", + ) from exc + + +def _heartbeat_lookup_statement(payload: AgentHeartbeatCreate) -> object: + statement = Agent.objects.filter_by(name=payload.name).statement + if payload.board_id is not None: + statement = statement.where(Agent.board_id == payload.board_id) + return statement + + +async def _create_agent_from_heartbeat( + session: AsyncSession, + *, + payload: AgentHeartbeatCreate, + actor: ActorContext, +) -> Agent: + if actor.actor_type == "agent": + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + if actor.actor_type == "user": + ctx = await _require_user_context(session, actor.user) + if not is_org_admin(ctx.member): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + board = await _require_board( + session, + payload.board_id, + user=actor.user, + write=True, + ) + gateway, client_config = await _require_gateway(session, board) + data: dict[str, Any] = { + "name": payload.name, + "board_id": board.id, + "heartbeat_config": DEFAULT_HEARTBEAT_CONFIG.copy(), + } + agent, raw_token, session_error = await _persist_new_agent( + session, + data=data, + client_config=client_config, + ) + await _record_session_creation( + session, + agent=agent, + session_error=session_error, + ) + await _provision_new_agent( + session, + agent=agent, + request=AgentProvisionRequest( + board=board, + gateway=gateway, + auth_token=raw_token, + user=actor.user, + options=ProvisionOptions(action="provision"), + ), + client_config=client_config, + ) + return agent + + +async def _handle_existing_user_heartbeat_agent( + session: AsyncSession, + *, + agent: Agent, + user: User | None, +) -> None: + ctx = await _require_user_context(session, user) + await _require_agent_access(session, agent=agent, ctx=ctx, write=True) + + if agent.agent_token_hash is not None: + return + + raw_token = generate_agent_token() + agent.agent_token_hash = hash_agent_token(raw_token) + if agent.heartbeat_config is None: + agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() + agent.provision_requested_at = utcnow() + agent.provision_action = "provision" + session.add(agent) + await session.commit() + await session.refresh(agent) + board = await _require_board( + session, + str(agent.board_id) if agent.board_id else None, + user=user, + write=True, + ) + gateway, client_config = await _require_gateway(session, board) + await _provision_new_agent( + session, + agent=agent, + request=AgentProvisionRequest( + board=board, + gateway=gateway, + auth_token=raw_token, + user=user, + options=ProvisionOptions(action="provision"), + ), + client_config=client_config, + ) + + +async def _ensure_heartbeat_session_key( + session: AsyncSession, + *, + agent: Agent, + actor: ActorContext, +) -> None: + if agent.openclaw_session_id: + return + board = await _require_board( + session, + str(agent.board_id) if agent.board_id else None, + user=actor.user if actor.actor_type == "user" else None, + write=actor.actor_type == "user", + ) + _, client_config = await _require_gateway(session, board) + session_key, session_error = await _ensure_gateway_session( + agent.name, + client_config, + ) + agent.openclaw_session_id = session_key + session.add(agent) + await _record_session_creation( + session, + agent=agent, + session_error=session_error, + ) + + +async def _commit_heartbeat( + session: AsyncSession, + *, + agent: Agent, + status_value: str | None, +) -> AgentRead: + if status_value: + agent.status = status_value + elif agent.status == "provisioning": + agent.status = "online" + agent.last_seen_at = utcnow() + agent.updated_at = utcnow() + _record_heartbeat(session, agent) + session.add(agent) + await session.commit() + await session.refresh(agent) + main_session_keys = await _get_gateway_main_session_keys(session) + return _to_agent_read(_with_computed_status(agent), main_session_keys) + + async def _send_wakeup_message( agent: Agent, config: GatewayClientConfig, verb: str = "provisioned", ) -> None: @@ -422,35 +1031,13 @@ async def stream_agents( @router.post("", response_model=AgentRead) -async def create_agent( # noqa: C901, PLR0912, PLR0915 +async def create_agent( payload: AgentCreate, session: AsyncSession = SESSION_DEP, actor: ActorContext = ACTOR_DEP, ) -> AgentRead: """Create and provision an agent.""" - if actor.actor_type == "user": - ctx = await _require_user_context(session, actor.user) - if not is_org_admin(ctx.member): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - if actor.actor_type == "agent": - if not actor.agent or not actor.agent.is_board_lead: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only board leads can create agents", - ) - if not actor.agent.board_id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Board lead must be assigned to a board", - ) - if payload.board_id and payload.board_id != actor.agent.board_id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Board leads can only create agents in their own board", - ) - payload = AgentCreate( - **{**payload.model_dump(), "board_id": actor.agent.board_id}, - ) + payload = await _coerce_agent_create_payload(session, payload, actor) board = await _require_board( session, @@ -461,119 +1048,35 @@ async def create_agent( # noqa: C901, PLR0912, PLR0915 gateway, client_config = await _require_gateway(session, board) data = payload.model_dump() requested_name = (data.get("name") or "").strip() - if requested_name: - existing = ( - await session.exec( - select(Agent) - .where(Agent.board_id == board.id) - .where(col(Agent.name).ilike(requested_name)), - ) - ).first() - if existing: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="An agent with this name already exists on this board.", - ) - # Prevent session/workspace collisions inside the gateway workspace. - # Agents on different boards can still share one gateway root. - existing_gateway = ( - await session.exec( - select(Agent) - .join(Board, col(Agent.board_id) == col(Board.id)) - .where(col(Board.gateway_id) == gateway.id) - .where(col(Agent.name).ilike(requested_name)), - ) - ).first() - if existing_gateway: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=( - "An agent with this name already exists in this gateway " - "workspace." - ), - ) - desired_session_key = _build_session_key(requested_name) - existing_session_key = ( - await session.exec( - select(Agent) - .join(Board, col(Agent.board_id) == col(Board.id)) - .where(col(Board.gateway_id) == gateway.id) - .where(col(Agent.openclaw_session_id) == desired_session_key), - ) - ).first() - if existing_session_key: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=( - "This agent name would collide with an existing workspace " - "session key. Pick a different name." - ), - ) - agent = Agent.model_validate(data) - agent.status = "provisioning" - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - if agent.heartbeat_config is None: - agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() - agent.provision_requested_at = utcnow() - agent.provision_action = "provision" - session_key, session_error = await _ensure_gateway_session( - agent.name, client_config, + await _ensure_unique_agent_name( + session, + board=board, + gateway=gateway, + requested_name=requested_name, + ) + agent, raw_token, session_error = await _persist_new_agent( + session, + data=data, + client_config=client_config, + ) + await _record_session_creation( + session, + agent=agent, + session_error=session_error, + ) + provision_request = AgentProvisionRequest( + board=board, + gateway=gateway, + auth_token=raw_token, + user=actor.user if actor.actor_type == "user" else None, + options=ProvisionOptions(action="provision"), + ) + await _provision_new_agent( + session, + agent=agent, + request=provision_request, + client_config=client_config, ) - agent.openclaw_session_id = session_key - session.add(agent) - await session.commit() - await session.refresh(agent) - if session_error: - record_activity( - session, - event_type="agent.session.failed", - message=f"Session sync failed for {agent.name}: {session_error}", - agent_id=agent.id, - ) - else: - record_activity( - session, - event_type="agent.session.created", - message=f"Session created for {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - try: - await provision_agent( - agent, - board, - gateway, - raw_token, - actor.user if actor.actor_type == "user" else None, - action="provision", - ) - await _send_wakeup_message(agent, client_config, verb="provisioned") - agent.provision_confirm_token_hash = None - agent.provision_requested_at = None - agent.provision_action = None - agent.updated_at = utcnow() - session.add(agent) - await session.commit() - record_activity( - session, - event_type="agent.provision", - message=f"Provisioned directly for {agent.name}.", - agent_id=agent.id, - ) - record_activity( - session, - event_type="agent.wakeup.sent", - message=f"Wakeup message sent to {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - except OpenClawGatewayError as exc: - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() - except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() main_session_keys = await _get_gateway_main_session_keys(session) return _to_agent_read(_with_computed_status(agent), main_session_keys) @@ -594,188 +1097,61 @@ async def get_agent( @router.patch("/{agent_id}", response_model=AgentRead) -async def update_agent( # noqa: C901, PLR0912, PLR0913, PLR0915 +async def update_agent( agent_id: str, payload: AgentUpdate, - *, - force: bool = False, + params: _AgentUpdateParams = AGENT_UPDATE_PARAMS_DEP, session: AsyncSession = SESSION_DEP, - auth: AuthContext = AUTH_DEP, - ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> AgentRead: """Update agent metadata and optionally reprovision.""" agent = await Agent.objects.by_id(agent_id).first(session) if agent is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - await _require_agent_access(session, agent=agent, ctx=ctx, write=True) + await _require_agent_access(session, agent=agent, ctx=params.ctx, write=True) updates = payload.model_dump(exclude_unset=True) make_main = updates.pop("is_gateway_main", None) - if make_main is True and not is_org_admin(ctx.member): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - if "status" in updates: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="status is controlled by agent heartbeat", - ) - if "board_id" in updates and updates["board_id"] is not None: - new_board = await _require_board(session, updates["board_id"]) - if new_board.organization_id != ctx.organization.id: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - if not await has_board_access( - session, member=ctx.member, board=new_board, write=True, - ): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - if not updates and not force and make_main is None: + await _validate_agent_update_inputs( + session, + ctx=params.ctx, + updates=updates, + make_main=make_main, + ) + if not updates and not params.force and make_main is None: main_session_keys = await _get_gateway_main_session_keys(session) return _to_agent_read(_with_computed_status(agent), main_session_keys) - main_gateway = await _find_gateway_for_main_session( - session, agent.openclaw_session_id, + main_gateway, gateway_for_main = await _apply_agent_update_mutations( + session, + agent=agent, + updates=updates, + make_main=make_main, ) - gateway_for_main: Gateway | None = None - if make_main is True: - board_source = updates.get("board_id") or agent.board_id - board_for_main = await _require_board(session, board_source) - gateway_for_main, _ = await _require_gateway(session, board_for_main) - updates["board_id"] = None - agent.is_board_lead = False - agent.openclaw_session_id = gateway_for_main.main_session_key - main_gateway = gateway_for_main - elif make_main is False: - agent.openclaw_session_id = None - if make_main is not True and "board_id" in updates: - await _require_board(session, updates["board_id"]) - for key, value in updates.items(): - setattr(agent, key, value) - if make_main is None and main_gateway is not None: - agent.board_id = None - agent.is_board_lead = False - agent.updated_at = utcnow() - if agent.heartbeat_config is None: - agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() + target = await _resolve_agent_update_target( + session, + agent=agent, + make_main=make_main, + main_gateway=main_gateway, + gateway_for_main=gateway_for_main, + ) + await _ensure_agent_update_session( + session, + agent=agent, + client_config=target.client_config, + ) + raw_token = _mark_agent_update_pending(agent) session.add(agent) await session.commit() await session.refresh(agent) - is_main_agent = False - board: Board | None = None - gateway: Gateway | None = None - client_config: GatewayClientConfig | None = None - if make_main is True: - is_main_agent = True - gateway = gateway_for_main - elif make_main is None and agent.board_id is None and main_gateway is not None: - is_main_agent = True - gateway = main_gateway - if is_main_agent: - if gateway is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Main agent requires a gateway main_session_key", - ) - if not gateway.main_session_key: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Gateway main_session_key is required", - ) - client_config = _gateway_client_config(gateway) - else: - if agent.board_id is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="board_id is required for non-main agents", - ) - board = await _require_board(session, agent.board_id) - gateway, client_config = await _require_gateway(session, board) - session_key = agent.openclaw_session_id or _build_session_key(agent.name) - try: - if client_config is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Gateway configuration is required", - ) - await ensure_session(session_key, config=client_config, label=agent.name) - if not agent.openclaw_session_id: - agent.openclaw_session_id = session_key - session.add(agent) - await session.commit() - await session.refresh(agent) - except OpenClawGatewayError as exc: - _record_instruction_failure(session, agent, str(exc), "update") - await session.commit() - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - agent.provision_requested_at = utcnow() - agent.provision_action = "update" - agent.status = "updating" - session.add(agent) - await session.commit() - await session.refresh(agent) - try: - if gateway is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Gateway configuration is required", - ) - if is_main_agent: - await provision_main_agent( - agent, - gateway, - raw_token, - auth.user, - action="update", - force_bootstrap=force, - reset_session=True, - ) - else: - if board is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="board is required for non-main agent provisioning", - ) - await provision_agent( - agent, - board, - gateway, - raw_token, - auth.user, - action="update", - force_bootstrap=force, - reset_session=True, - ) - await _send_wakeup_message(agent, client_config, verb="updated") - agent.provision_confirm_token_hash = None - agent.provision_requested_at = None - agent.provision_action = None - agent.status = "online" - agent.updated_at = utcnow() - session.add(agent) - await session.commit() - record_activity( - session, - event_type="agent.update.direct", - message=f"Updated directly for {agent.name}.", - agent_id=agent.id, - ) - record_activity( - session, - event_type="agent.wakeup.sent", - message=f"Wakeup message sent to {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - except OpenClawGatewayError as exc: - _record_instruction_failure(session, agent, str(exc), "update") - await session.commit() - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail=f"Gateway update failed: {exc}", - ) from exc - except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover - _record_instruction_failure(session, agent, str(exc), "update") - await session.commit() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Unexpected error updating agent provisioning.", - ) from exc + provision_request = _AgentUpdateProvisionRequest( + target=target, + raw_token=raw_token, + user=params.auth.user, + force_bootstrap=params.force, + ) + await _provision_updated_agent( + session, + agent=agent, + request=provision_request, + ) main_session_keys = await _get_gateway_main_session_keys(session) return _to_agent_read(_with_computed_status(agent), main_session_keys) @@ -798,22 +1174,15 @@ async def heartbeat_agent( if not is_org_admin(ctx.member): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) await _require_agent_access(session, agent=agent, ctx=ctx, write=True) - if payload.status: - agent.status = payload.status - elif agent.status == "provisioning": - agent.status = "online" - agent.last_seen_at = utcnow() - agent.updated_at = utcnow() - _record_heartbeat(session, agent) - session.add(agent) - await session.commit() - await session.refresh(agent) - main_session_keys = await _get_gateway_main_session_keys(session) - return _to_agent_read(_with_computed_status(agent), main_session_keys) + return await _commit_heartbeat( + session, + agent=agent, + status_value=payload.status, + ) @router.post("/heartbeat", response_model=AgentRead) -async def heartbeat_or_create_agent( # noqa: C901, PLR0912, PLR0915 +async def heartbeat_or_create_agent( payload: AgentHeartbeatCreate, session: AsyncSession = SESSION_DEP, actor: ActorContext = ACTOR_DEP, @@ -829,179 +1198,32 @@ async def heartbeat_or_create_agent( # noqa: C901, PLR0912, PLR0915 actor=actor, ) - statement = Agent.objects.filter_by(name=payload.name).statement - if payload.board_id is not None: - statement = statement.where(Agent.board_id == payload.board_id) - agent = (await session.exec(statement)).first() + agent = (await session.exec(_heartbeat_lookup_statement(payload))).first() if agent is None: - if actor.actor_type == "agent": - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - if actor.actor_type == "user": - ctx = await _require_user_context(session, actor.user) - if not is_org_admin(ctx.member): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - board = await _require_board( + agent = await _create_agent_from_heartbeat( session, - payload.board_id, - user=actor.user, - write=True, + payload=payload, + actor=actor, ) - gateway, client_config = await _require_gateway(session, board) - agent = Agent( - name=payload.name, - status="provisioning", - board_id=board.id, - heartbeat_config=DEFAULT_HEARTBEAT_CONFIG.copy(), - ) - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - agent.provision_requested_at = utcnow() - agent.provision_action = "provision" - session_key, session_error = await _ensure_gateway_session( - agent.name, client_config, - ) - agent.openclaw_session_id = session_key - session.add(agent) - await session.commit() - await session.refresh(agent) - if session_error: - record_activity( - session, - event_type="agent.session.failed", - message=f"Session sync failed for {agent.name}: {session_error}", - agent_id=agent.id, - ) - else: - record_activity( - session, - event_type="agent.session.created", - message=f"Session created for {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - try: - await provision_agent( - agent, board, gateway, raw_token, actor.user, action="provision", - ) - await _send_wakeup_message(agent, client_config, verb="provisioned") - agent.provision_confirm_token_hash = None - agent.provision_requested_at = None - agent.provision_action = None - agent.updated_at = utcnow() - session.add(agent) - await session.commit() - record_activity( - session, - event_type="agent.provision", - message=f"Provisioned directly for {agent.name}.", - agent_id=agent.id, - ) - record_activity( - session, - event_type="agent.wakeup.sent", - message=f"Wakeup message sent to {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - except OpenClawGatewayError as exc: - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() - except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() elif actor.actor_type == "user": - ctx = await _require_user_context(session, actor.user) - await _require_agent_access(session, agent=agent, ctx=ctx, write=True) - - if agent.agent_token_hash is None: - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - if agent.heartbeat_config is None: - agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() - agent.provision_requested_at = utcnow() - agent.provision_action = "provision" - session.add(agent) - await session.commit() - await session.refresh(agent) - try: - board = await _require_board( - session, - str(agent.board_id) if agent.board_id else None, - user=actor.user, - write=True, - ) - gateway, client_config = await _require_gateway(session, board) - await provision_agent( - agent, board, gateway, raw_token, actor.user, action="provision", - ) - await _send_wakeup_message(agent, client_config, verb="provisioned") - agent.provision_confirm_token_hash = None - agent.provision_requested_at = None - agent.provision_action = None - agent.updated_at = utcnow() - session.add(agent) - await session.commit() - record_activity( - session, - event_type="agent.provision", - message=f"Provisioned directly for {agent.name}.", - agent_id=agent.id, - ) - record_activity( - session, - event_type="agent.wakeup.sent", - message=f"Wakeup message sent to {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - except OpenClawGatewayError as exc: - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() - except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover - _record_instruction_failure(session, agent, str(exc), "provision") - await session.commit() + await _handle_existing_user_heartbeat_agent( + session, + agent=agent, + user=actor.user, + ) elif actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - if not agent.openclaw_session_id: - board = await _require_board( - session, - str(agent.board_id) if agent.board_id else None, - user=actor.user if actor.actor_type == "user" else None, - write=actor.actor_type == "user", - ) - gateway, client_config = await _require_gateway(session, board) - session_key, session_error = await _ensure_gateway_session( - agent.name, client_config, - ) - agent.openclaw_session_id = session_key - if session_error: - record_activity( - session, - event_type="agent.session.failed", - message=f"Session sync failed for {agent.name}: {session_error}", - agent_id=agent.id, - ) - else: - record_activity( - session, - event_type="agent.session.created", - message=f"Session created for {agent.name}.", - agent_id=agent.id, - ) - await session.commit() - if payload.status: - agent.status = payload.status - elif agent.status == "provisioning": - agent.status = "online" - agent.last_seen_at = utcnow() - agent.updated_at = utcnow() - _record_heartbeat(session, agent) - session.add(agent) - await session.commit() - await session.refresh(agent) - main_session_keys = await _get_gateway_main_session_keys(session) - return _to_agent_read(_with_computed_status(agent), main_session_keys) + await _ensure_heartbeat_session_key( + session, + agent=agent, + actor=actor, + ) + return await _commit_heartbeat( + session, + agent=agent, + status_value=payload.status, + ) @router.delete("/{agent_id}", response_model=OkResponse) diff --git a/backend/app/api/board_onboarding.py b/backend/app/api/board_onboarding.py index 9caffc88..0059aa2c 100644 --- a/backend/app/api/board_onboarding.py +++ b/backend/app/api/board_onboarding.py @@ -3,9 +3,7 @@ from __future__ import annotations import logging -import re from typing import TYPE_CHECKING -from uuid import uuid4 from fastapi import APIRouter, Depends, HTTPException, status from pydantic import ValidationError @@ -19,7 +17,6 @@ from app.api.deps import ( require_admin_auth, require_admin_or_agent, ) -from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.config import settings from app.core.time import utcnow from app.db.session import get_session @@ -29,7 +26,6 @@ from app.integrations.openclaw_gateway import ( ensure_session, send_message, ) -from app.models.agents import Agent from app.models.board_onboarding import BoardOnboardingSession from app.models.gateways import Gateway from app.schemas.board_onboarding import ( @@ -43,7 +39,11 @@ from app.schemas.board_onboarding import ( BoardOnboardingUserProfile, ) from app.schemas.boards import BoardRead -from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent +from app.services.board_leads import ( + LeadAgentOptions, + LeadAgentRequest, + ensure_board_lead_agent, +) if TYPE_CHECKING: from sqlmodel.ext.asyncio.session import AsyncSession @@ -72,93 +72,85 @@ async def _gateway_config( return gateway, GatewayClientConfig(url=gateway.url, token=gateway.token) -def _build_session_key(agent_name: str) -> str: - slug = re.sub(r"[^a-z0-9]+", "-", agent_name.lower()).strip("-") - return f"agent:{slug or uuid4().hex}:main" - - -def _lead_agent_name(_board: Board) -> str: - return "Lead Agent" - - -def _lead_session_key(board: Board) -> str: - return f"agent:lead-{board.id}:main" - - -async def _ensure_lead_agent( # noqa: PLR0913 - session: AsyncSession, - board: Board, - gateway: Gateway, - config: GatewayClientConfig, - auth: AuthContext, - *, - agent_name: str | None = None, - identity_profile: dict[str, str] | None = None, -) -> Agent: - existing = ( - await Agent.objects.filter_by(board_id=board.id) - .filter(col(Agent.is_board_lead).is_(True)) - .first(session) - ) - if existing: - desired_name = agent_name or _lead_agent_name(board) - if existing.name != desired_name: - existing.name = desired_name - session.add(existing) - await session.commit() - await session.refresh(existing) - return existing - - merged_identity_profile = { - "role": "Board Lead", - "communication_style": "direct, concise, practical", - "emoji": ":gear:", - } - if identity_profile: - merged_identity_profile.update( - { - key: value.strip() - for key, value in identity_profile.items() - if value.strip() - }, - ) - - agent = Agent( - name=agent_name or _lead_agent_name(board), - status="provisioning", - board_id=board.id, - is_board_lead=True, - heartbeat_config=DEFAULT_HEARTBEAT_CONFIG.copy(), - identity_profile=merged_identity_profile, - ) - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - agent.provision_requested_at = utcnow() - agent.provision_action = "provision" - agent.openclaw_session_id = _lead_session_key(board) - session.add(agent) - await session.commit() - await session.refresh(agent) - +def _parse_draft_user_profile( + draft_goal: object, +) -> BoardOnboardingUserProfile | None: + if not isinstance(draft_goal, dict): + return None + raw_profile = draft_goal.get("user_profile") + if raw_profile is None: + return None try: - await provision_agent( - agent, board, gateway, raw_token, auth.user, action="provision", - ) - await ensure_session(agent.openclaw_session_id, config=config, label=agent.name) - await send_message( - ( - f"Hello {agent.name}. Your workspace has been provisioned.\n\n" - "Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run it once " - "then delete it. Begin heartbeats after startup." - ), - session_key=agent.openclaw_session_id, - config=config, - deliver=True, - ) - except OpenClawGatewayError: - # Best-effort provisioning. Board confirmation should still succeed. - pass - return agent + return BoardOnboardingUserProfile.model_validate(raw_profile) + except ValidationError: + return None + + +def _parse_draft_lead_agent( + draft_goal: object, +) -> BoardOnboardingLeadAgentDraft | None: + if not isinstance(draft_goal, dict): + return None + raw_lead = draft_goal.get("lead_agent") + if raw_lead is None: + return None + try: + return BoardOnboardingLeadAgentDraft.model_validate(raw_lead) + except ValidationError: + return None + + +def _apply_user_profile( + auth: AuthContext, + profile: BoardOnboardingUserProfile | None, +) -> bool: + if auth.user is None or profile is None: + return False + + changed = False + if profile.preferred_name is not None: + auth.user.preferred_name = profile.preferred_name + changed = True + if profile.pronouns is not None: + auth.user.pronouns = profile.pronouns + changed = True + if profile.timezone is not None: + auth.user.timezone = profile.timezone + changed = True + if profile.notes is not None: + auth.user.notes = profile.notes + changed = True + if profile.context is not None: + auth.user.context = profile.context + changed = True + return changed + + +def _lead_agent_options( + lead_agent: BoardOnboardingLeadAgentDraft | None, +) -> LeadAgentOptions: + if lead_agent is None: + return LeadAgentOptions(action="provision") + + lead_identity_profile: dict[str, str] = {} + if lead_agent.identity_profile: + lead_identity_profile.update(lead_agent.identity_profile) + if lead_agent.autonomy_level: + lead_identity_profile["autonomy_level"] = lead_agent.autonomy_level + if lead_agent.verbosity: + lead_identity_profile["verbosity"] = lead_agent.verbosity + if lead_agent.output_format: + lead_identity_profile["output_format"] = lead_agent.output_format + if lead_agent.update_cadence: + lead_identity_profile["update_cadence"] = lead_agent.update_cadence + if lead_agent.custom_instructions: + lead_identity_profile["custom_instructions"] = lead_agent.custom_instructions + + return LeadAgentOptions( + agent_name=lead_agent.name, + identity_profile=lead_identity_profile or None, + action="provision", + ) @router.get("", response_model=BoardOnboardingRead) @@ -400,7 +392,7 @@ async def agent_onboarding_update( @router.post("/confirm", response_model=BoardRead) -async def confirm_onboarding( # noqa: C901, PLR0912, PLR0915 +async def confirm_onboarding( payload: BoardOnboardingConfirm, board: Board = BOARD_USER_WRITE_DEP, session: AsyncSession = SESSION_DEP, @@ -425,73 +417,26 @@ async def confirm_onboarding( # noqa: C901, PLR0912, PLR0915 onboarding.status = "confirmed" onboarding.updated_at = utcnow() - user_profile: BoardOnboardingUserProfile | None = None - lead_agent: BoardOnboardingLeadAgentDraft | None = None - if isinstance(onboarding.draft_goal, dict): - raw_profile = onboarding.draft_goal.get("user_profile") - if raw_profile is not None: - try: - user_profile = BoardOnboardingUserProfile.model_validate(raw_profile) - except ValidationError: - user_profile = None - raw_lead = onboarding.draft_goal.get("lead_agent") - if raw_lead is not None: - try: - lead_agent = BoardOnboardingLeadAgentDraft.model_validate(raw_lead) - except ValidationError: - lead_agent = None + user_profile = _parse_draft_user_profile(onboarding.draft_goal) + if _apply_user_profile(auth, user_profile) and auth.user is not None: + session.add(auth.user) - if auth.user and user_profile: - changed = False - if user_profile.preferred_name is not None: - auth.user.preferred_name = user_profile.preferred_name - changed = True - if user_profile.pronouns is not None: - auth.user.pronouns = user_profile.pronouns - changed = True - if user_profile.timezone is not None: - auth.user.timezone = user_profile.timezone - changed = True - if user_profile.notes is not None: - auth.user.notes = user_profile.notes - changed = True - if user_profile.context is not None: - auth.user.context = user_profile.context - changed = True - if changed: - session.add(auth.user) - - lead_identity_profile: dict[str, str] = {} - lead_name: str | None = None - if lead_agent: - lead_name = lead_agent.name - if lead_agent.identity_profile: - lead_identity_profile.update(lead_agent.identity_profile) - if lead_agent.autonomy_level: - lead_identity_profile["autonomy_level"] = lead_agent.autonomy_level - if lead_agent.verbosity: - lead_identity_profile["verbosity"] = lead_agent.verbosity - if lead_agent.output_format: - lead_identity_profile["output_format"] = lead_agent.output_format - if lead_agent.update_cadence: - lead_identity_profile["update_cadence"] = lead_agent.update_cadence - if lead_agent.custom_instructions: - lead_identity_profile["custom_instructions"] = ( - lead_agent.custom_instructions - ) + lead_agent = _parse_draft_lead_agent(onboarding.draft_goal) + lead_options = _lead_agent_options(lead_agent) gateway, config = await _gateway_config(session, board) session.add(board) session.add(onboarding) await session.commit() await session.refresh(board) - await _ensure_lead_agent( + await ensure_board_lead_agent( session, - board, - gateway, - config, - auth, - agent_name=lead_name, - identity_profile=lead_identity_profile or None, + request=LeadAgentRequest( + board=board, + gateway=gateway, + config=config, + user=auth.user, + options=lead_options, + ), ) return board diff --git a/backend/app/api/gateways.py b/backend/app/api/gateways.py index 59606823..29ed12b3 100644 --- a/backend/app/api/gateways.py +++ b/backend/app/api/gateways.py @@ -34,6 +34,8 @@ from app.schemas.gateways import ( from app.schemas.pagination import DefaultLimitOffsetPage from app.services.agent_provisioning import ( DEFAULT_HEARTBEAT_CONFIG, + MainAgentProvisionRequest, + ProvisionOptions, provision_main_agent, ) from app.services.template_sync import ( @@ -187,7 +189,15 @@ async def _ensure_main_agent( await session.commit() await session.refresh(agent) try: - await provision_main_agent(agent, gateway, raw_token, auth.user, action=action) + await provision_main_agent( + agent, + MainAgentProvisionRequest( + gateway=gateway, + auth_token=raw_token, + user=auth.user, + options=ProvisionOptions(action=action), + ), + ) await ensure_session( gateway.main_session_key, config=GatewayClientConfig(url=gateway.url, token=gateway.token), diff --git a/backend/app/api/tasks.py b/backend/app/api/tasks.py index 8a5517bd..6da057ab 100644 --- a/backend/app/api/tasks.py +++ b/backend/app/api/tasks.py @@ -7,6 +7,7 @@ import json from collections import deque from collections.abc import AsyncIterator, Sequence from contextlib import suppress +from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, cast from uuid import UUID @@ -69,6 +70,7 @@ if TYPE_CHECKING: from sqlmodel.ext.asyncio.session import AsyncSession from app.core.auth import AuthContext + from app.models.users import User router = APIRouter(prefix="/boards/{board_id}/tasks", tags=["tasks"]) @@ -366,6 +368,22 @@ async def _notify_agent_on_task_assign( await session.commit() +async def notify_agent_on_task_assign( + *, + session: AsyncSession, + board: Board, + task: Task, + agent: Agent, +) -> None: + """Notify an assignee via gateway after task assignment.""" + await _notify_agent_on_task_assign( + session=session, + board=board, + task=task, + agent=agent, + ) + + async def _notify_lead_on_task_create( *, session: AsyncSession, @@ -476,8 +494,194 @@ async def _notify_lead_on_task_unassigned( await session.commit() +def _status_values(status_filter: str | None) -> list[str]: + if not status_filter: + return [] + values = [s.strip() for s in status_filter.split(",") if s.strip()] + if any(value not in ALLOWED_STATUSES for value in values): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Unsupported task status filter.", + ) + return values + + +def _task_list_statement( + *, + board_id: UUID, + status_filter: str | None, + assigned_agent_id: UUID | None, + unassigned: bool | None, +) -> object: + statement = select(Task).where(Task.board_id == board_id) + statuses = _status_values(status_filter) + if statuses: + statement = statement.where(col(Task.status).in_(statuses)) + if assigned_agent_id is not None: + statement = statement.where(col(Task.assigned_agent_id) == assigned_agent_id) + if unassigned: + statement = statement.where(col(Task.assigned_agent_id).is_(None)) + return statement.order_by(col(Task.created_at).desc()) + + +async def _task_read_page( + *, + session: AsyncSession, + board_id: UUID, + tasks: Sequence[Task], +) -> list[TaskRead]: + if not tasks: + return [] + + task_ids = [task.id for task in tasks] + deps_map = await dependency_ids_by_task_id( + session, + board_id=board_id, + task_ids=task_ids, + ) + dep_ids: list[UUID] = [] + for value in deps_map.values(): + dep_ids.extend(value) + dep_status = await dependency_status_by_id( + session, + board_id=board_id, + dependency_ids=list({*dep_ids}), + ) + + output: list[TaskRead] = [] + for task in tasks: + dep_list = deps_map.get(task.id, []) + blocked_by = blocked_by_dependency_ids( + dependency_ids=dep_list, + status_by_id=dep_status, + ) + if task.status == "done": + blocked_by = [] + output.append( + TaskRead.model_validate(task, from_attributes=True).model_copy( + update={ + "depends_on_task_ids": dep_list, + "blocked_by_task_ids": blocked_by, + "is_blocked": bool(blocked_by), + }, + ), + ) + return output + + +async def _stream_dependency_state( + session: AsyncSession, + *, + board_id: UUID, + rows: list[tuple[ActivityEvent, Task | None]], +) -> tuple[dict[UUID, list[UUID]], dict[UUID, str]]: + task_ids = [ + task.id + for event, task in rows + if task is not None and event.event_type != "task.comment" + ] + if not task_ids: + return {}, {} + + deps_map = await dependency_ids_by_task_id( + session, + board_id=board_id, + task_ids=list({*task_ids}), + ) + dep_ids: list[UUID] = [] + for value in deps_map.values(): + dep_ids.extend(value) + if not dep_ids: + return deps_map, {} + + dep_status = await dependency_status_by_id( + session, + board_id=board_id, + dependency_ids=list({*dep_ids}), + ) + return deps_map, dep_status + + +def _task_event_payload( + event: ActivityEvent, + task: Task | None, + *, + deps_map: dict[UUID, list[UUID]], + dep_status: dict[UUID, str], +) -> dict[str, object]: + payload: dict[str, object] = {"type": event.event_type} + if event.event_type == "task.comment": + payload["comment"] = _serialize_comment(event) + return payload + if task is None: + payload["task"] = None + return payload + + dep_list = deps_map.get(task.id, []) + blocked_by = blocked_by_dependency_ids( + dependency_ids=dep_list, + status_by_id=dep_status, + ) + if task.status == "done": + blocked_by = [] + payload["task"] = ( + TaskRead.model_validate(task, from_attributes=True) + .model_copy( + update={ + "depends_on_task_ids": dep_list, + "blocked_by_task_ids": blocked_by, + "is_blocked": bool(blocked_by), + }, + ) + .model_dump(mode="json") + ) + return payload + + +async def _task_event_generator( + *, + request: Request, + board_id: UUID, + since_dt: datetime, +) -> AsyncIterator[dict[str, str]]: + last_seen = since_dt + seen_ids: set[UUID] = set() + seen_queue: deque[UUID] = deque() + + while True: + if await request.is_disconnected(): + break + + async with async_session_maker() as session: + rows = await _fetch_task_events(session, board_id, last_seen) + deps_map, dep_status = await _stream_dependency_state( + session, + board_id=board_id, + rows=rows, + ) + + for event, task in rows: + if event.id in seen_ids: + continue + seen_ids.add(event.id) + seen_queue.append(event.id) + if len(seen_queue) > SSE_SEEN_MAX: + oldest = seen_queue.popleft() + seen_ids.discard(oldest) + last_seen = max(event.created_at, last_seen) + + payload = _task_event_payload( + event, + task, + deps_map=deps_map, + dep_status=dep_status, + ) + yield {"event": "task", "data": json.dumps(payload)} + await asyncio.sleep(2) + + @router.get("/stream") -async def stream_tasks( # noqa: C901 +async def stream_tasks( request: Request, board: Board = BOARD_READ_DEP, _actor: ActorContext = ACTOR_DEP, @@ -485,79 +689,18 @@ async def stream_tasks( # noqa: C901 ) -> EventSourceResponse: """Stream task and task-comment events as SSE payloads.""" since_dt = _parse_since(since) or utcnow() - seen_ids: set[UUID] = set() - seen_queue: deque[UUID] = deque() - - async def event_generator() -> AsyncIterator[dict[str, str]]: # noqa: C901 - last_seen = since_dt - while True: - if await request.is_disconnected(): - break - deps_map: dict[UUID, list[UUID]] = {} - dep_status: dict[UUID, str] = {} - async with async_session_maker() as session: - rows = await _fetch_task_events(session, board.id, last_seen) - task_ids = [ - task.id - for event, task in rows - if task is not None and event.event_type != "task.comment" - ] - if task_ids: - deps_map = await dependency_ids_by_task_id( - session, - board_id=board.id, - task_ids=list({*task_ids}), - ) - dep_ids: list[UUID] = [] - for value in deps_map.values(): - dep_ids.extend(value) - if dep_ids: - dep_status = await dependency_status_by_id( - session, - board_id=board.id, - dependency_ids=list({*dep_ids}), - ) - for event, task in rows: - if event.id in seen_ids: - continue - seen_ids.add(event.id) - seen_queue.append(event.id) - if len(seen_queue) > SSE_SEEN_MAX: - oldest = seen_queue.popleft() - seen_ids.discard(oldest) - last_seen = max(event.created_at, last_seen) - payload: dict[str, object] = {"type": event.event_type} - if event.event_type == "task.comment": - payload["comment"] = _serialize_comment(event) - elif task is None: - payload["task"] = None - else: - dep_list = deps_map.get(task.id, []) - blocked_by = blocked_by_dependency_ids( - dependency_ids=dep_list, - status_by_id=dep_status, - ) - if task.status == "done": - blocked_by = [] - payload["task"] = ( - TaskRead.model_validate(task, from_attributes=True) - .model_copy( - update={ - "depends_on_task_ids": dep_list, - "blocked_by_task_ids": blocked_by, - "is_blocked": bool(blocked_by), - }, - ) - .model_dump(mode="json") - ) - yield {"event": "task", "data": json.dumps(payload)} - await asyncio.sleep(2) - - return EventSourceResponse(event_generator(), ping=15) + return EventSourceResponse( + _task_event_generator( + request=request, + board_id=board.id, + since_dt=since_dt, + ), + ping=15, + ) @router.get("", response_model=DefaultLimitOffsetPage[TaskRead]) -async def list_tasks( # noqa: C901 +async def list_tasks( status_filter: str | None = STATUS_QUERY, assigned_agent_id: UUID | None = None, unassigned: bool | None = None, @@ -566,58 +709,21 @@ async def list_tasks( # noqa: C901 _actor: ActorContext = ACTOR_DEP, ) -> DefaultLimitOffsetPage[TaskRead]: """List board tasks with optional status and assignment filters.""" - statement = select(Task).where(Task.board_id == board.id) - if status_filter: - statuses = [s.strip() for s in status_filter.split(",") if s.strip()] - if statuses: - if any(status_value not in ALLOWED_STATUSES for status_value in statuses): - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Unsupported task status filter.", - ) - statement = statement.where(col(Task.status).in_(statuses)) - if assigned_agent_id is not None: - statement = statement.where(col(Task.assigned_agent_id) == assigned_agent_id) - if unassigned: - statement = statement.where(col(Task.assigned_agent_id).is_(None)) - statement = statement.order_by(col(Task.created_at).desc()) + statement = _task_list_statement( + board_id=board.id, + status_filter=status_filter, + assigned_agent_id=assigned_agent_id, + unassigned=unassigned, + ) async def _transform(items: Sequence[object]) -> Sequence[object]: tasks = cast(Sequence[Task], items) - if not tasks: - return [] - task_ids = [task.id for task in tasks] - deps_map = await dependency_ids_by_task_id( - session, board_id=board.id, task_ids=task_ids, - ) - dep_ids: list[UUID] = [] - for value in deps_map.values(): - dep_ids.extend(value) - dep_status = await dependency_status_by_id( - session, + return await _task_read_page( + session=session, board_id=board.id, - dependency_ids=list({*dep_ids}), + tasks=tasks, ) - output: list[TaskRead] = [] - for task in tasks: - dep_list = deps_map.get(task.id, []) - blocked_by = blocked_by_dependency_ids( - dependency_ids=dep_list, status_by_id=dep_status, - ) - if task.status == "done": - blocked_by = [] - output.append( - TaskRead.model_validate(task, from_attributes=True).model_copy( - update={ - "depends_on_task_ids": dep_list, - "blocked_by_task_ids": blocked_by, - "is_blocked": bool(blocked_by), - }, - ), - ) - return output - return await paginate(session, statement, transformer=_transform) @@ -700,7 +806,7 @@ async def create_task( response_model=TaskRead, responses={409: {"model": BlockedTaskError}}, ) -async def update_task( # noqa: C901, PLR0912, PLR0915 +async def update_task( payload: TaskUpdate, task: Task = TASK_DEP, session: AsyncSession = SESSION_DEP, @@ -714,359 +820,38 @@ async def update_task( # noqa: C901, PLR0912, PLR0915 ) board_id = task.board_id if actor.actor_type == "user" and actor.user is not None: - board = await Board.objects.by_id(board_id).first(session) - if board is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - await require_board_access(session, user=actor.user, board=board, write=True) - + await _require_task_user_write_access( + session, + board_id=board_id, + user=actor.user, + ) previous_status = task.status previous_assigned = task.assigned_agent_id updates = payload.model_dump(exclude_unset=True) - comment = updates.pop("comment", None) + comment = cast(str | None, updates.pop("comment", None)) depends_on_task_ids = cast( list[UUID] | None, updates.pop("depends_on_task_ids", None), ) - - requested_fields = set(updates) - if comment is not None: - requested_fields.add("comment") - if depends_on_task_ids is not None: - requested_fields.add("depends_on_task_ids") - - async def _current_dep_ids() -> list[UUID]: - deps_map = await dependency_ids_by_task_id( - session, board_id=board_id, task_ids=[task.id], - ) - return deps_map.get(task.id, []) - - async def _blocked_by(dep_ids: Sequence[UUID]) -> list[UUID]: - if not dep_ids: - return [] - dep_status = await dependency_status_by_id( - session, - board_id=board_id, - dependency_ids=list(dep_ids), - ) - return blocked_by_dependency_ids( - dependency_ids=list(dep_ids), status_by_id=dep_status, - ) - - # Lead agent: delegation only. - # Assign/unassign, resolve review, and manage dependencies. - if actor.actor_type == "agent" and actor.agent and actor.agent.is_board_lead: - allowed_fields = {"assigned_agent_id", "status", "depends_on_task_ids"} - if comment is not None or not requested_fields.issubset(allowed_fields): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=( - "Board leads can only assign/unassign tasks, update " - "dependencies, or resolve review tasks." - ), - ) - - normalized_deps: list[UUID] | None = None - if depends_on_task_ids is not None: - if task.status == "done": - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=("Cannot change task dependencies after a task is done."), - ) - normalized_deps = await replace_task_dependencies( - session, - board_id=board_id, - task_id=task.id, - depends_on_task_ids=depends_on_task_ids, - ) - - effective_deps = ( - normalized_deps if normalized_deps is not None else await _current_dep_ids() - ) - blocked_by = await _blocked_by(effective_deps) - - # Blocked tasks cannot be assigned or moved out of inbox (unless already done). - if blocked_by and task.status != "done": - task.status = "inbox" - task.assigned_agent_id = None - task.in_progress_at = None - else: - if "assigned_agent_id" in updates: - assigned_id = updates["assigned_agent_id"] - if assigned_id: - agent = await Agent.objects.by_id(assigned_id).first(session) - if agent is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - if agent.is_board_lead: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Board leads cannot assign tasks to themselves.", - ) - if ( - agent.board_id - and task.board_id - and agent.board_id != task.board_id - ): - raise HTTPException(status_code=status.HTTP_409_CONFLICT) - task.assigned_agent_id = agent.id - else: - task.assigned_agent_id = None - - if "status" in updates: - if task.status != "review": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=( - "Board leads can only change status when a task is " - "in review." - ), - ) - if updates["status"] not in {"done", "inbox"}: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=( - "Board leads can only move review tasks to done " - "or inbox." - ), - ) - if updates["status"] == "inbox": - task.assigned_agent_id = None - task.in_progress_at = None - task.status = updates["status"] - - task.updated_at = utcnow() - session.add(task) - if task.status != previous_status: - event_type = "task.status_changed" - message = f"Task moved to {task.status}: {task.title}." - else: - event_type = "task.updated" - message = f"Task updated: {task.title}." - record_activity( - session, - event_type=event_type, - task_id=task.id, - message=message, - agent_id=actor.agent.id, - ) - await _reconcile_dependents_for_dependency_toggle( - session, - board_id=board_id, - dependency_task=task, - previous_status=previous_status, - actor_agent_id=actor.agent.id, - ) - await session.commit() - await session.refresh(task) - - if task.assigned_agent_id and task.assigned_agent_id != previous_assigned: - assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first( - session, - ) - if assigned_agent: - board = ( - await Board.objects.by_id(task.board_id).first(session) - if task.board_id - else None - ) - if board: - await _notify_agent_on_task_assign( - session=session, - board=board, - task=task, - agent=assigned_agent, - ) - - dep_ids = await _current_dep_ids() - blocked_ids = await _blocked_by(dep_ids) - if task.status == "done": - blocked_ids = [] - return TaskRead.model_validate(task, from_attributes=True).model_copy( - update={ - "depends_on_task_ids": dep_ids, - "blocked_by_task_ids": blocked_ids, - "is_blocked": bool(blocked_ids), - }, - ) - - # Non-lead agent: can only change status + comment, and cannot start blocked tasks. - if actor.actor_type == "agent": - if ( - actor.agent - and actor.agent.board_id - and task.board_id - and actor.agent.board_id != task.board_id - ): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - allowed_fields = {"status", "comment"} - if depends_on_task_ids is not None or not set(updates).issubset(allowed_fields): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - if "status" in updates: - if updates["status"] != "inbox": - dep_ids = await _current_dep_ids() - blocked_ids = await _blocked_by(dep_ids) - if blocked_ids: - raise _blocked_task_error(blocked_ids) - if updates["status"] == "inbox": - task.assigned_agent_id = None - task.in_progress_at = None - else: - task.assigned_agent_id = actor.agent.id if actor.agent else None - if updates["status"] == "in_progress": - task.in_progress_at = utcnow() - else: - # Admin user: dependencies can be edited until the task is done. - admin_normalized_deps: list[UUID] | None = None - if depends_on_task_ids is not None: - if task.status == "done": - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=("Cannot change task dependencies after a task is done."), - ) - admin_normalized_deps = await replace_task_dependencies( - session, - board_id=board_id, - task_id=task.id, - depends_on_task_ids=depends_on_task_ids, - ) - - effective_deps = ( - admin_normalized_deps - if admin_normalized_deps is not None - else await _current_dep_ids() - ) - blocked_ids = await _blocked_by(effective_deps) - - target_status = cast(str, updates.get("status", task.status)) - if blocked_ids and not (task.status == "done" and target_status == "done"): - # Blocked tasks cannot be assigned or moved out of inbox. - # If the task is already in flight, force it back to inbox and unassign it. - task.status = "inbox" - task.assigned_agent_id = None - task.in_progress_at = None - updates["status"] = "inbox" - updates["assigned_agent_id"] = None - - if "status" in updates: - if updates["status"] == "inbox": - task.assigned_agent_id = None - task.in_progress_at = None - elif updates["status"] == "in_progress": - task.in_progress_at = utcnow() - - assigned_agent_id = updates.get("assigned_agent_id") - if assigned_agent_id: - agent = await Agent.objects.by_id(assigned_agent_id).first(session) - if agent is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - if agent.board_id and task.board_id and agent.board_id != task.board_id: - raise HTTPException(status_code=status.HTTP_409_CONFLICT) - - for key, value in updates.items(): - setattr(task, key, value) - task.updated_at = utcnow() - - if "status" in updates and updates["status"] == "review": - comment_text = (comment or "").strip() - if not comment_text and not await has_valid_recent_comment( - session, - task, - task.assigned_agent_id, - task.in_progress_at, - ): - raise _comment_validation_error() - - session.add(task) - await session.commit() - await session.refresh(task) - - if comment is not None and comment.strip(): - event = ActivityEvent( - event_type="task.comment", - message=comment, - task_id=task.id, - agent_id=actor.agent.id - if actor.actor_type == "agent" and actor.agent - else None, - ) - session.add(event) - await session.commit() - - if "status" in updates and task.status != previous_status: - event_type = "task.status_changed" - message = f"Task moved to {task.status}: {task.title}." - else: - event_type = "task.updated" - message = f"Task updated: {task.title}." - actor_agent_id = ( - actor.agent.id if actor.actor_type == "agent" and actor.agent else None - ) - record_activity( - session, - event_type=event_type, - task_id=task.id, - message=message, - agent_id=actor_agent_id, - ) - await _reconcile_dependents_for_dependency_toggle( - session, + update = _TaskUpdateInput( + task=task, + actor=actor, board_id=board_id, - dependency_task=task, previous_status=previous_status, - actor_agent_id=actor_agent_id, + previous_assigned=previous_assigned, + updates=updates, + comment=comment, + depends_on_task_ids=depends_on_task_ids, ) - await session.commit() + if actor.actor_type == "agent" and actor.agent and actor.agent.is_board_lead: + return await _apply_lead_task_update(session, update=update) - if ( - task.status == "inbox" - and task.assigned_agent_id is None - and (previous_status != "inbox" or previous_assigned is not None) - ): - board = ( - await Board.objects.by_id(task.board_id).first(session) - if task.board_id - else None - ) - if board: - await _notify_lead_on_task_unassigned( - session=session, - board=board, - task=task, - ) - if task.assigned_agent_id and task.assigned_agent_id != previous_assigned: - if ( - actor.actor_type == "agent" - and actor.agent - and task.assigned_agent_id == actor.agent.id - ): - # Don't notify the actor about their own assignment. - pass - else: - assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first( - session, - ) - if assigned_agent: - board = ( - await Board.objects.by_id(task.board_id).first(session) - if task.board_id - else None - ) - if board: - await _notify_agent_on_task_assign( - session=session, - board=board, - task=task, - agent=assigned_agent, - ) - - dep_ids = await _current_dep_ids() - blocked_ids = await _blocked_by(dep_ids) - if task.status == "done": - blocked_ids = [] - return TaskRead.model_validate(task, from_attributes=True).model_copy( - update={ - "depends_on_task_ids": dep_ids, - "blocked_by_task_ids": blocked_ids, - "is_blocked": bool(blocked_ids), - }, + if actor.actor_type == "agent": + await _apply_non_lead_agent_task_rules(session, update=update) + else: + await _apply_admin_task_rules(session, update=update) + return await _finalize_updated_task( + session, + update=update, ) @@ -1125,21 +910,21 @@ async def list_task_comments( return await paginate(session, statement) -@router.post("/{task_id}/comments", response_model=TaskCommentRead) -async def create_task_comment( # noqa: C901, PLR0912 - payload: TaskCommentCreate, - task: Task = TASK_DEP, - session: AsyncSession = SESSION_DEP, - actor: ActorContext = ACTOR_DEP, -) -> ActivityEvent: - """Create a task comment and notify relevant agents.""" +async def _validate_task_comment_access( + session: AsyncSession, + *, + task: Task, + actor: ActorContext, +) -> None: if task.board_id is None: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) + if actor.actor_type == "user" and actor.user is not None: board = await Board.objects.by_id(task.board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) await require_board_access(session, user=actor.user, board=board, write=True) + if ( actor.actor_type == "agent" and actor.agent @@ -1155,18 +940,28 @@ async def create_task_comment( # noqa: C901, PLR0912 "or on tasks they created." ), ) - event = ActivityEvent( - event_type="task.comment", - message=payload.message, - task_id=task.id, - agent_id=actor.agent.id - if actor.actor_type == "agent" and actor.agent - else None, - ) - session.add(event) - await session.commit() - await session.refresh(event) - mention_names = extract_mentions(payload.message) + + +def _comment_actor_id(actor: ActorContext) -> UUID | None: + if actor.actor_type == "agent" and actor.agent: + return actor.agent.id + return None + + +def _comment_actor_name(actor: ActorContext) -> str: + if actor.actor_type == "agent" and actor.agent: + return actor.agent.name + return "User" + + +async def _comment_targets( + session: AsyncSession, + *, + task: Task, + message: str, + actor: ActorContext, +) -> tuple[dict[UUID, Agent], list[str]]: + mention_names = extract_mentions(message) targets: dict[UUID, Agent] = {} if mention_names and task.board_id: for agent in await Agent.objects.filter_by(board_id=task.board_id).all(session): @@ -1178,48 +973,619 @@ async def create_task_comment( # noqa: C901, PLR0912 ) if assigned_agent: targets[assigned_agent.id] = assigned_agent + if actor.actor_type == "agent" and actor.agent: targets.pop(actor.agent.id, None) - if targets: + return targets, mention_names + + +@dataclass(frozen=True, slots=True) +class _TaskCommentNotifyRequest: + task: Task + actor: ActorContext + message: str + targets: dict[UUID, Agent] + mention_names: list[str] + + +async def _notify_task_comment_targets( + session: AsyncSession, + *, + request: _TaskCommentNotifyRequest, +) -> None: + if not request.targets: + return + board = ( + await Board.objects.by_id(request.task.board_id).first(session) + if request.task.board_id + else None + ) + config = await _gateway_config(session, board) if board else None + if not board or not config: + return + + snippet = _truncate_snippet(request.message) + actor_name = _comment_actor_name(request.actor) + for agent in request.targets.values(): + if not agent.openclaw_session_id: + continue + mentioned = matches_agent_mention(agent, request.mention_names) + header = "TASK MENTION" if mentioned else "NEW TASK COMMENT" + action_line = ( + "You were mentioned in this comment." + if mentioned + else "A new comment was posted on your task." + ) + notification = ( + f"{header}\n" + f"Board: {board.name}\n" + f"Task: {request.task.title}\n" + f"Task ID: {request.task.id}\n" + f"From: {actor_name}\n\n" + f"{action_line}\n\n" + f"Comment:\n{snippet}\n\n" + "If you are mentioned but not assigned, reply in the task " + "thread but do not change task status." + ) + with suppress(OpenClawGatewayError): + await _send_agent_task_message( + session_key=agent.openclaw_session_id, + config=config, + agent_name=agent.name, + message=notification, + ) + + +@dataclass(slots=True) +class _TaskUpdateInput: + task: Task + actor: ActorContext + board_id: UUID + previous_status: str + previous_assigned: UUID | None + updates: dict[str, object] + comment: str | None + depends_on_task_ids: list[UUID] | None + + +async def _task_dep_ids( + session: AsyncSession, + *, + board_id: UUID, + task_id: UUID, +) -> list[UUID]: + deps_map = await dependency_ids_by_task_id( + session, + board_id=board_id, + task_ids=[task_id], + ) + return deps_map.get(task_id, []) + + +async def _task_blocked_ids( + session: AsyncSession, + *, + board_id: UUID, + dep_ids: Sequence[UUID], +) -> list[UUID]: + if not dep_ids: + return [] + dep_status = await dependency_status_by_id( + session, + board_id=board_id, + dependency_ids=list(dep_ids), + ) + return blocked_by_dependency_ids( + dependency_ids=list(dep_ids), + status_by_id=dep_status, + ) + + +async def _task_read_response( + session: AsyncSession, + *, + task: Task, + board_id: UUID, +) -> TaskRead: + dep_ids = await _task_dep_ids(session, board_id=board_id, task_id=task.id) + blocked_ids = await _task_blocked_ids( + session, + board_id=board_id, + dep_ids=dep_ids, + ) + if task.status == "done": + blocked_ids = [] + return TaskRead.model_validate(task, from_attributes=True).model_copy( + update={ + "depends_on_task_ids": dep_ids, + "blocked_by_task_ids": blocked_ids, + "is_blocked": bool(blocked_ids), + }, + ) + + +async def _require_task_user_write_access( + session: AsyncSession, + *, + board_id: UUID, + user: User | None, +) -> None: + board = await Board.objects.by_id(board_id).first(session) + if board is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + await require_board_access(session, user=user, board=board, write=True) + + +def _lead_requested_fields(update: _TaskUpdateInput) -> set[str]: + requested_fields = set(update.updates) + if update.comment is not None: + requested_fields.add("comment") + if update.depends_on_task_ids is not None: + requested_fields.add("depends_on_task_ids") + return requested_fields + + +def _validate_lead_update_request(update: _TaskUpdateInput) -> None: + allowed_fields = {"assigned_agent_id", "status", "depends_on_task_ids"} + requested_fields = _lead_requested_fields(update) + if update.comment is not None or not requested_fields.issubset(allowed_fields): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=( + "Board leads can only assign/unassign tasks, update " + "dependencies, or resolve review tasks." + ), + ) + + +async def _lead_effective_dependencies( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> tuple[list[UUID], list[UUID]]: + normalized_deps: list[UUID] | None = None + if update.depends_on_task_ids is not None: + if update.task.status == "done": + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=("Cannot change task dependencies after a task is done."), + ) + normalized_deps = await replace_task_dependencies( + session, + board_id=update.board_id, + task_id=update.task.id, + depends_on_task_ids=update.depends_on_task_ids, + ) + effective_deps = ( + normalized_deps + if normalized_deps is not None + else await _task_dep_ids( + session, + board_id=update.board_id, + task_id=update.task.id, + ) + ) + blocked_by = await _task_blocked_ids( + session, + board_id=update.board_id, + dep_ids=effective_deps, + ) + return effective_deps, blocked_by + + +async def _lead_apply_assignment( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> None: + if "assigned_agent_id" not in update.updates: + return + assigned_id = cast(UUID | None, update.updates["assigned_agent_id"]) + if not assigned_id: + update.task.assigned_agent_id = None + return + agent = await Agent.objects.by_id(assigned_id).first(session) + if agent is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + if agent.is_board_lead: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Board leads cannot assign tasks to themselves.", + ) + if ( + agent.board_id + and update.task.board_id + and agent.board_id != update.task.board_id + ): + raise HTTPException(status_code=status.HTTP_409_CONFLICT) + update.task.assigned_agent_id = agent.id + + +def _lead_apply_status(update: _TaskUpdateInput) -> None: + if "status" not in update.updates: + return + if update.task.status != "review": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=( + "Board leads can only change status when a task is " + "in review." + ), + ) + target_status = cast(str, update.updates["status"]) + if target_status not in {"done", "inbox"}: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=( + "Board leads can only move review tasks to done " + "or inbox." + ), + ) + if target_status == "inbox": + update.task.assigned_agent_id = None + update.task.in_progress_at = None + update.task.status = target_status + + +def _task_event_details(task: Task, previous_status: str) -> tuple[str, str]: + if task.status != previous_status: + return "task.status_changed", f"Task moved to {task.status}: {task.title}." + return "task.updated", f"Task updated: {task.title}." + + +async def _lead_notify_new_assignee( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> None: + if ( + not update.task.assigned_agent_id + or update.task.assigned_agent_id == update.previous_assigned + ): + return + assigned_agent = await Agent.objects.by_id(update.task.assigned_agent_id).first( + session, + ) + if assigned_agent is None: + return + board = ( + await Board.objects.by_id(update.task.board_id).first(session) + if update.task.board_id + else None + ) + if board: + await _notify_agent_on_task_assign( + session=session, + board=board, + task=update.task, + agent=assigned_agent, + ) + + +async def _apply_lead_task_update( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> TaskRead: + if update.actor.actor_type != "agent" or update.actor.agent is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + _validate_lead_update_request(update) + _effective_deps, blocked_by = await _lead_effective_dependencies( + session, + update=update, + ) + + if blocked_by and update.task.status != "done": + update.task.status = "inbox" + update.task.assigned_agent_id = None + update.task.in_progress_at = None + else: + await _lead_apply_assignment(session, update=update) + _lead_apply_status(update) + + update.task.updated_at = utcnow() + session.add(update.task) + event_type, message = _task_event_details(update.task, update.previous_status) + record_activity( + session, + event_type=event_type, + task_id=update.task.id, + message=message, + agent_id=update.actor.agent.id, + ) + await _reconcile_dependents_for_dependency_toggle( + session, + board_id=update.board_id, + dependency_task=update.task, + previous_status=update.previous_status, + actor_agent_id=update.actor.agent.id, + ) + await session.commit() + await session.refresh(update.task) + await _lead_notify_new_assignee(session, update=update) + return await _task_read_response( + session, + task=update.task, + board_id=update.board_id, + ) + + +async def _apply_non_lead_agent_task_rules( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> None: + if update.actor.actor_type != "agent": + return + if ( + update.actor.agent + and update.actor.agent.board_id + and update.task.board_id + and update.actor.agent.board_id != update.task.board_id + ): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + allowed_fields = {"status", "comment"} + if update.depends_on_task_ids is not None or not set(update.updates).issubset( + allowed_fields, + ): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + if "status" in update.updates: + status_value = cast(str, update.updates["status"]) + if status_value != "inbox": + dep_ids = await _task_dep_ids( + session, + board_id=update.board_id, + task_id=update.task.id, + ) + blocked_ids = await _task_blocked_ids( + session, + board_id=update.board_id, + dep_ids=dep_ids, + ) + if blocked_ids: + raise _blocked_task_error(blocked_ids) + if status_value == "inbox": + update.task.assigned_agent_id = None + update.task.in_progress_at = None + else: + update.task.assigned_agent_id = ( + update.actor.agent.id if update.actor.agent else None + ) + if status_value == "in_progress": + update.task.in_progress_at = utcnow() + + +async def _apply_admin_task_rules( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> None: + admin_normalized_deps: list[UUID] | None = None + if update.depends_on_task_ids is not None: + if update.task.status == "done": + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=("Cannot change task dependencies after a task is done."), + ) + admin_normalized_deps = await replace_task_dependencies( + session, + board_id=update.board_id, + task_id=update.task.id, + depends_on_task_ids=update.depends_on_task_ids, + ) + + effective_deps = ( + admin_normalized_deps + if admin_normalized_deps is not None + else await _task_dep_ids( + session, + board_id=update.board_id, + task_id=update.task.id, + ) + ) + blocked_ids = await _task_blocked_ids( + session, + board_id=update.board_id, + dep_ids=effective_deps, + ) + target_status = cast(str, update.updates.get("status", update.task.status)) + if blocked_ids and not (update.task.status == "done" and target_status == "done"): + update.task.status = "inbox" + update.task.assigned_agent_id = None + update.task.in_progress_at = None + update.updates["status"] = "inbox" + update.updates["assigned_agent_id"] = None + + if "status" in update.updates: + status_value = cast(str, update.updates["status"]) + if status_value == "inbox": + update.task.assigned_agent_id = None + update.task.in_progress_at = None + elif status_value == "in_progress": + update.task.in_progress_at = utcnow() + + assigned_agent_id = cast(UUID | None, update.updates.get("assigned_agent_id")) + if assigned_agent_id: + agent = await Agent.objects.by_id(assigned_agent_id).first(session) + if agent is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + if ( + agent.board_id + and update.task.board_id + and agent.board_id != update.task.board_id + ): + raise HTTPException(status_code=status.HTTP_409_CONFLICT) + + +async def _record_task_comment_from_update( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> None: + if update.comment is None or not update.comment.strip(): + return + event = ActivityEvent( + event_type="task.comment", + message=update.comment, + task_id=update.task.id, + agent_id=update.actor.agent.id + if update.actor.actor_type == "agent" and update.actor.agent + else None, + ) + session.add(event) + await session.commit() + + +async def _record_task_update_activity( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> None: + event_type, message = _task_event_details(update.task, update.previous_status) + actor_agent_id = ( + update.actor.agent.id + if update.actor.actor_type == "agent" and update.actor.agent + else None + ) + record_activity( + session, + event_type=event_type, + task_id=update.task.id, + message=message, + agent_id=actor_agent_id, + ) + await _reconcile_dependents_for_dependency_toggle( + session, + board_id=update.board_id, + dependency_task=update.task, + previous_status=update.previous_status, + actor_agent_id=actor_agent_id, + ) + await session.commit() + + +async def _notify_task_update_assignment_changes( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> None: + if ( + update.task.status == "inbox" + and update.task.assigned_agent_id is None + and ( + update.previous_status != "inbox" + or update.previous_assigned is not None + ) + ): board = ( - await Board.objects.by_id(task.board_id).first(session) - if task.board_id + await Board.objects.by_id(update.task.board_id).first(session) + if update.task.board_id else None ) - config = await _gateway_config(session, board) if board else None - if board and config: - snippet = _truncate_snippet(payload.message) - actor_name = ( - actor.agent.name - if actor.actor_type == "agent" and actor.agent - else "User" + if board: + await _notify_lead_on_task_unassigned( + session=session, + board=board, + task=update.task, ) - for agent in targets.values(): - if not agent.openclaw_session_id: - continue - mentioned = matches_agent_mention(agent, mention_names) - header = "TASK MENTION" if mentioned else "NEW TASK COMMENT" - action_line = ( - "You were mentioned in this comment." - if mentioned - else "A new comment was posted on your task." - ) - message = ( - f"{header}\n" - f"Board: {board.name}\n" - f"Task: {task.title}\n" - f"Task ID: {task.id}\n" - f"From: {actor_name}\n\n" - f"{action_line}\n\n" - f"Comment:\n{snippet}\n\n" - "If you are mentioned but not assigned, reply in the task " - "thread but do not change task status." - ) - with suppress(OpenClawGatewayError): - await _send_agent_task_message( - session_key=agent.openclaw_session_id, - config=config, - agent_name=agent.name, - message=message, - ) + + if ( + not update.task.assigned_agent_id + or update.task.assigned_agent_id == update.previous_assigned + ): + return + if ( + update.actor.actor_type == "agent" + and update.actor.agent + and update.task.assigned_agent_id == update.actor.agent.id + ): + return + assigned_agent = await Agent.objects.by_id(update.task.assigned_agent_id).first( + session, + ) + if assigned_agent is None: + return + board = ( + await Board.objects.by_id(update.task.board_id).first(session) + if update.task.board_id + else None + ) + if board: + await _notify_agent_on_task_assign( + session=session, + board=board, + task=update.task, + agent=assigned_agent, + ) + + +async def _finalize_updated_task( + session: AsyncSession, + *, + update: _TaskUpdateInput, +) -> TaskRead: + for key, value in update.updates.items(): + setattr(update.task, key, value) + update.task.updated_at = utcnow() + + if "status" in update.updates and cast(str, update.updates["status"]) == "review": + comment_text = (update.comment or "").strip() + if not comment_text and not await has_valid_recent_comment( + session, + update.task, + update.task.assigned_agent_id, + update.task.in_progress_at, + ): + raise _comment_validation_error() + + session.add(update.task) + await session.commit() + await session.refresh(update.task) + await _record_task_comment_from_update(session, update=update) + await _record_task_update_activity(session, update=update) + await _notify_task_update_assignment_changes(session, update=update) + + return await _task_read_response( + session, + task=update.task, + board_id=update.board_id, + ) + + +@router.post("/{task_id}/comments", response_model=TaskCommentRead) +async def create_task_comment( + payload: TaskCommentCreate, + task: Task = TASK_DEP, + session: AsyncSession = SESSION_DEP, + actor: ActorContext = ACTOR_DEP, +) -> ActivityEvent: + """Create a task comment and notify relevant agents.""" + await _validate_task_comment_access(session, task=task, actor=actor) + event = ActivityEvent( + event_type="task.comment", + message=payload.message, + task_id=task.id, + agent_id=_comment_actor_id(actor), + ) + session.add(event) + await session.commit() + await session.refresh(event) + targets, mention_names = await _comment_targets( + session, + task=task, + message=payload.message, + actor=actor, + ) + await _notify_task_comment_targets( + session, + request=_TaskCommentNotifyRequest( + task=task, + actor=actor, + message=payload.message, + targets=targets, + mention_names=mention_names, + ), + ) return event diff --git a/backend/app/models/activity_events.py b/backend/app/models/activity_events.py index 8d1087b6..61b26cc9 100644 --- a/backend/app/models/activity_events.py +++ b/backend/app/models/activity_events.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlmodel import Field @@ -10,6 +10,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class ActivityEvent(QueryModel, table=True): """Discrete activity event tied to tasks and agents.""" diff --git a/backend/app/models/agents.py b/backend/app/models/agents.py index 460f6122..45535420 100644 --- a/backend/app/models/agents.py +++ b/backend/app/models/agents.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from typing import Any from uuid import UUID, uuid4 @@ -12,6 +12,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class Agent(QueryModel, table=True): """Agent configuration and lifecycle state persisted in the database.""" diff --git a/backend/app/models/approvals.py b/backend/app/models/approvals.py index 6ec35ec1..b950fe72 100644 --- a/backend/app/models/approvals.py +++ b/backend/app/models/approvals.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class Approval(QueryModel, table=True): """Approval request and decision metadata for gated operations.""" diff --git a/backend/app/models/board_group_memory.py b/backend/app/models/board_group_memory.py index bdbb1bc2..ac10b10e 100644 --- a/backend/app/models/board_group_memory.py +++ b/backend/app/models/board_group_memory.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class BoardGroupMemory(QueryModel, table=True): """Persisted memory items associated with a board group.""" diff --git a/backend/app/models/board_groups.py b/backend/app/models/board_groups.py index a7684b4a..26cbda50 100644 --- a/backend/app/models/board_groups.py +++ b/backend/app/models/board_groups.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlmodel import Field @@ -10,6 +10,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.tenancy import TenantScoped +RUNTIME_ANNOTATION_TYPES = (datetime,) + class BoardGroup(TenantScoped, table=True): """Logical grouping container for boards within an organization.""" diff --git a/backend/app/models/board_memory.py b/backend/app/models/board_memory.py index e0f8b1a2..632ac215 100644 --- a/backend/app/models/board_memory.py +++ b/backend/app/models/board_memory.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class BoardMemory(QueryModel, table=True): """Persisted memory item attached directly to a board.""" diff --git a/backend/app/models/board_onboarding.py b/backend/app/models/board_onboarding.py index 5791ed72..9d60494d 100644 --- a/backend/app/models/board_onboarding.py +++ b/backend/app/models/board_onboarding.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class BoardOnboardingSession(QueryModel, table=True): """Persisted onboarding conversation and draft goal data for a board.""" diff --git a/backend/app/models/boards.py b/backend/app/models/boards.py index 7864b985..ea37cf72 100644 --- a/backend/app/models/boards.py +++ b/backend/app/models/boards.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.tenancy import TenantScoped +RUNTIME_ANNOTATION_TYPES = (datetime,) + class Board(TenantScoped, table=True): """Primary board entity grouping tasks, agents, and goal metadata.""" diff --git a/backend/app/models/gateways.py b/backend/app/models/gateways.py index c2d744c0..9e41d883 100644 --- a/backend/app/models/gateways.py +++ b/backend/app/models/gateways.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlmodel import Field @@ -10,6 +10,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class Gateway(QueryModel, table=True): """Configured external gateway endpoint and authentication settings.""" diff --git a/backend/app/models/organization_board_access.py b/backend/app/models/organization_board_access.py index dc5f74af..a2da23b1 100644 --- a/backend/app/models/organization_board_access.py +++ b/backend/app/models/organization_board_access.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class OrganizationBoardAccess(QueryModel, table=True): """Member-specific board permissions within an organization.""" diff --git a/backend/app/models/organization_invite_board_access.py b/backend/app/models/organization_invite_board_access.py index bdb31bc9..28dbbf0a 100644 --- a/backend/app/models/organization_invite_board_access.py +++ b/backend/app/models/organization_invite_board_access.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class OrganizationInviteBoardAccess(QueryModel, table=True): """Invite-specific board permissions applied after invite acceptance.""" diff --git a/backend/app/models/organization_invites.py b/backend/app/models/organization_invites.py index 3f416bc8..2071d57d 100644 --- a/backend/app/models/organization_invites.py +++ b/backend/app/models/organization_invites.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class OrganizationInvite(QueryModel, table=True): """Invitation record granting prospective organization access.""" diff --git a/backend/app/models/organization_members.py b/backend/app/models/organization_members.py index 18d6506e..f3a97153 100644 --- a/backend/app/models/organization_members.py +++ b/backend/app/models/organization_members.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class OrganizationMember(QueryModel, table=True): """Membership row linking a user to an organization and permissions.""" diff --git a/backend/app/models/organizations.py b/backend/app/models/organizations.py index 5fd9cc49..cd4c8c18 100644 --- a/backend/app/models/organizations.py +++ b/backend/app/models/organizations.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class Organization(QueryModel, table=True): """Top-level organization tenant record.""" diff --git a/backend/app/models/task_dependencies.py b/backend/app/models/task_dependencies.py index 77e8245b..4a3a9e9c 100644 --- a/backend/app/models/task_dependencies.py +++ b/backend/app/models/task_dependencies.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import CheckConstraint, UniqueConstraint @@ -11,6 +11,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.tenancy import TenantScoped +RUNTIME_ANNOTATION_TYPES = (datetime,) + class TaskDependency(TenantScoped, table=True): """Directed dependency edge between two tasks in the same board.""" diff --git a/backend/app/models/task_fingerprints.py b/backend/app/models/task_fingerprints.py index 997025fa..a55c30f0 100644 --- a/backend/app/models/task_fingerprints.py +++ b/backend/app/models/task_fingerprints.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlmodel import Field @@ -10,6 +10,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.base import QueryModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class TaskFingerprint(QueryModel, table=True): """Hashed task-content fingerprint associated with a board and task.""" diff --git a/backend/app/models/tasks.py b/backend/app/models/tasks.py index b971ea5d..6c8285fe 100644 --- a/backend/app/models/tasks.py +++ b/backend/app/models/tasks.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from uuid import UUID, uuid4 from sqlmodel import Field @@ -10,6 +10,8 @@ from sqlmodel import Field from app.core.time import utcnow from app.models.tenancy import TenantScoped +RUNTIME_ANNOTATION_TYPES = (datetime,) + class Task(TenantScoped, table=True): """Board-scoped task entity with ownership, status, and timing fields.""" diff --git a/backend/app/schemas/activity_events.py b/backend/app/schemas/activity_events.py index 986165e8..8c02dd4b 100644 --- a/backend/app/schemas/activity_events.py +++ b/backend/app/schemas/activity_events.py @@ -2,11 +2,13 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 -from uuid import UUID # noqa: TCH003 +from datetime import datetime +from uuid import UUID from sqlmodel import SQLModel +RUNTIME_ANNOTATION_TYPES = (datetime, UUID) + class ActivityEventRead(SQLModel): """Serialized activity event payload returned by activity endpoints.""" diff --git a/backend/app/schemas/approvals.py b/backend/app/schemas/approvals.py index ca71bef0..d0c95007 100644 --- a/backend/app/schemas/approvals.py +++ b/backend/app/schemas/approvals.py @@ -2,15 +2,16 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from typing import Literal, Self -from uuid import UUID # noqa: TCH003 +from uuid import UUID from pydantic import model_validator from sqlmodel import SQLModel ApprovalStatus = Literal["pending", "approved", "rejected"] STATUS_REQUIRED_ERROR = "status is required" +RUNTIME_ANNOTATION_TYPES = (datetime, UUID) class ApprovalBase(SQLModel): diff --git a/backend/app/schemas/board_group_heartbeat.py b/backend/app/schemas/board_group_heartbeat.py index 3e56b627..40e786f3 100644 --- a/backend/app/schemas/board_group_heartbeat.py +++ b/backend/app/schemas/board_group_heartbeat.py @@ -3,10 +3,12 @@ from __future__ import annotations from typing import Any -from uuid import UUID # noqa: TCH003 +from uuid import UUID from sqlmodel import SQLModel +RUNTIME_ANNOTATION_TYPES = (UUID,) + class BoardGroupHeartbeatApply(SQLModel): """Request payload for heartbeat policy updates.""" diff --git a/backend/app/schemas/board_group_memory.py b/backend/app/schemas/board_group_memory.py index 2a6ffb0f..ef8ba516 100644 --- a/backend/app/schemas/board_group_memory.py +++ b/backend/app/schemas/board_group_memory.py @@ -2,12 +2,14 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 -from uuid import UUID # noqa: TCH003 +from datetime import datetime +from uuid import UUID from sqlmodel import SQLModel -from app.schemas.common import NonEmptyStr # noqa: TCH001 +from app.schemas.common import NonEmptyStr + +RUNTIME_ANNOTATION_TYPES = (datetime, UUID, NonEmptyStr) class BoardGroupMemoryCreate(SQLModel): diff --git a/backend/app/schemas/board_groups.py b/backend/app/schemas/board_groups.py index f5e599e7..69b66f8b 100644 --- a/backend/app/schemas/board_groups.py +++ b/backend/app/schemas/board_groups.py @@ -2,11 +2,13 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 -from uuid import UUID # noqa: TCH003 +from datetime import datetime +from uuid import UUID from sqlmodel import SQLModel +RUNTIME_ANNOTATION_TYPES = (datetime, UUID) + class BoardGroupBase(SQLModel): """Shared board-group fields for create/read operations.""" diff --git a/backend/app/schemas/board_memory.py b/backend/app/schemas/board_memory.py index 88be62e9..96bf2cee 100644 --- a/backend/app/schemas/board_memory.py +++ b/backend/app/schemas/board_memory.py @@ -2,12 +2,14 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 -from uuid import UUID # noqa: TCH003 +from datetime import datetime +from uuid import UUID from sqlmodel import SQLModel -from app.schemas.common import NonEmptyStr # noqa: TCH001 +from app.schemas.common import NonEmptyStr + +RUNTIME_ANNOTATION_TYPES = (datetime, UUID, NonEmptyStr) class BoardMemoryCreate(SQLModel): diff --git a/backend/app/schemas/boards.py b/backend/app/schemas/boards.py index 4af6e907..4ce0bd7c 100644 --- a/backend/app/schemas/boards.py +++ b/backend/app/schemas/boards.py @@ -2,9 +2,9 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from typing import Self -from uuid import UUID # noqa: TCH003 +from uuid import UUID from pydantic import model_validator from sqlmodel import SQLModel @@ -13,6 +13,7 @@ _ERR_GOAL_FIELDS_REQUIRED = ( "Confirmed goal boards require objective and success_metrics" ) _ERR_GATEWAY_REQUIRED = "gateway_id is required" +RUNTIME_ANNOTATION_TYPES = (datetime, UUID) class BoardBase(SQLModel): diff --git a/backend/app/schemas/gateway_api.py b/backend/app/schemas/gateway_api.py index 4cbf32ec..9d5b1b8b 100644 --- a/backend/app/schemas/gateway_api.py +++ b/backend/app/schemas/gateway_api.py @@ -4,7 +4,9 @@ from __future__ import annotations from sqlmodel import SQLModel -from app.schemas.common import NonEmptyStr # noqa: TCH001 +from app.schemas.common import NonEmptyStr + +RUNTIME_ANNOTATION_TYPES = (NonEmptyStr,) class GatewaySessionMessageRequest(SQLModel): diff --git a/backend/app/schemas/gateway_coordination.py b/backend/app/schemas/gateway_coordination.py index 9ff80b9f..42edcf9c 100644 --- a/backend/app/schemas/gateway_coordination.py +++ b/backend/app/schemas/gateway_coordination.py @@ -3,11 +3,13 @@ from __future__ import annotations from typing import Literal -from uuid import UUID # noqa: TCH003 +from uuid import UUID from sqlmodel import Field, SQLModel -from app.schemas.common import NonEmptyStr # noqa: TCH001 +from app.schemas.common import NonEmptyStr + +RUNTIME_ANNOTATION_TYPES = (UUID, NonEmptyStr) def _lead_reply_tags() -> list[str]: diff --git a/backend/app/schemas/gateways.py b/backend/app/schemas/gateways.py index d66a1536..df424eb7 100644 --- a/backend/app/schemas/gateways.py +++ b/backend/app/schemas/gateways.py @@ -2,12 +2,14 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 -from uuid import UUID # noqa: TCH003 +from datetime import datetime +from uuid import UUID from pydantic import field_validator from sqlmodel import Field, SQLModel +RUNTIME_ANNOTATION_TYPES = (datetime, UUID) + class GatewayBase(SQLModel): """Shared gateway fields used across create/read payloads.""" diff --git a/backend/app/schemas/metrics.py b/backend/app/schemas/metrics.py index c379c7cb..bb476f58 100644 --- a/backend/app/schemas/metrics.py +++ b/backend/app/schemas/metrics.py @@ -2,11 +2,13 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from typing import Literal from sqlmodel import SQLModel +RUNTIME_ANNOTATION_TYPES = (datetime,) + class DashboardSeriesPoint(SQLModel): """Single numeric time-series point.""" diff --git a/backend/app/schemas/organizations.py b/backend/app/schemas/organizations.py index 8afc4a60..0a17abde 100644 --- a/backend/app/schemas/organizations.py +++ b/backend/app/schemas/organizations.py @@ -2,11 +2,13 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 -from uuid import UUID # noqa: TCH003 +from datetime import datetime +from uuid import UUID from sqlmodel import Field, SQLModel +RUNTIME_ANNOTATION_TYPES = (datetime, UUID) + class OrganizationRead(SQLModel): """Organization payload returned by read endpoints.""" diff --git a/backend/app/schemas/tasks.py b/backend/app/schemas/tasks.py index a871e88f..6ae534d5 100644 --- a/backend/app/schemas/tasks.py +++ b/backend/app/schemas/tasks.py @@ -2,17 +2,20 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 +from datetime import datetime from typing import Literal, Self -from uuid import UUID # noqa: TCH003 +from uuid import UUID from pydantic import field_validator, model_validator from sqlmodel import Field, SQLModel -from app.schemas.common import NonEmptyStr # noqa: TCH001 +from app.schemas.common import NonEmptyStr TaskStatus = Literal["inbox", "in_progress", "review", "done"] STATUS_REQUIRED_ERROR = "status is required" +# Keep these symbols as runtime globals so Pydantic can resolve +# deferred annotations reliably. +RUNTIME_ANNOTATION_TYPES = (datetime, UUID, NonEmptyStr) class TaskBase(SQLModel): diff --git a/backend/app/schemas/users.py b/backend/app/schemas/users.py index 95fb7d69..a6e8268c 100644 --- a/backend/app/schemas/users.py +++ b/backend/app/schemas/users.py @@ -2,10 +2,12 @@ from __future__ import annotations -from uuid import UUID # noqa: TCH003 +from uuid import UUID from sqlmodel import SQLModel +RUNTIME_ANNOTATION_TYPES = (UUID,) + class UserBase(SQLModel): """Common user profile fields shared across user payload schemas.""" diff --git a/backend/app/schemas/view_models.py b/backend/app/schemas/view_models.py index c4766b45..4c14faf0 100644 --- a/backend/app/schemas/view_models.py +++ b/backend/app/schemas/view_models.py @@ -2,18 +2,28 @@ from __future__ import annotations -from datetime import datetime # noqa: TCH003 -from uuid import UUID # noqa: TCH003 +from datetime import datetime +from uuid import UUID from sqlmodel import Field, SQLModel -from app.schemas.agents import AgentRead # noqa: TCH001 -from app.schemas.approvals import ApprovalRead # noqa: TCH001 -from app.schemas.board_groups import BoardGroupRead # noqa: TCH001 -from app.schemas.board_memory import BoardMemoryRead # noqa: TCH001 -from app.schemas.boards import BoardRead # noqa: TCH001 +from app.schemas.agents import AgentRead +from app.schemas.approvals import ApprovalRead +from app.schemas.board_groups import BoardGroupRead +from app.schemas.board_memory import BoardMemoryRead +from app.schemas.boards import BoardRead from app.schemas.tasks import TaskRead +RUNTIME_ANNOTATION_TYPES = ( + datetime, + UUID, + AgentRead, + ApprovalRead, + BoardGroupRead, + BoardMemoryRead, + BoardRead, +) + class TaskCardRead(TaskRead): """Task read model enriched with assignee and approval counters.""" diff --git a/backend/app/services/agent_provisioning.py b/backend/app/services/agent_provisioning.py index 21eed80b..a89add6f 100644 --- a/backend/app/services/agent_provisioning.py +++ b/backend/app/services/agent_provisioning.py @@ -6,6 +6,7 @@ import hashlib import json import re from contextlib import suppress +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, cast from uuid import uuid4 @@ -88,6 +89,36 @@ MAIN_TEMPLATE_MAP = { } +@dataclass(frozen=True, slots=True) +class ProvisionOptions: + """Toggles controlling provisioning write/reset behavior.""" + + action: str = "provision" + force_bootstrap: bool = False + reset_session: bool = False + + +@dataclass(frozen=True, slots=True) +class AgentProvisionRequest: + """Inputs required to provision a board-scoped agent.""" + + board: Board + gateway: Gateway + auth_token: str + user: User | None + options: ProvisionOptions = field(default_factory=ProvisionOptions) + + +@dataclass(frozen=True, slots=True) +class MainAgentProvisionRequest: + """Inputs required to provision a gateway main agent.""" + + gateway: Gateway + auth_token: str + user: User | None + options: ProvisionOptions = field(default_factory=ProvisionOptions) + + def _repo_root() -> Path: return Path(__file__).resolve().parents[3] @@ -114,31 +145,48 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None: return agent_id or None -def _extract_agent_id(payload: object) -> str | None: # noqa: C901 - def _from_list(items: object) -> str | None: - if not isinstance(items, list): - return None - for item in items: - if isinstance(item, str) and item.strip(): - return item.strip() - if not isinstance(item, dict): - continue - for key in ("id", "agentId", "agent_id"): - raw = item.get(key) - if isinstance(raw, str) and raw.strip(): - return raw.strip() +def _clean_str(value: object) -> str | None: + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + +def _extract_agent_id_from_item(item: object) -> str | None: + if isinstance(item, str): + return _clean_str(item) + if not isinstance(item, dict): return None + for key in ("id", "agentId", "agent_id"): + agent_id = _clean_str(item.get(key)) + if agent_id: + return agent_id + return None + + +def _extract_agent_id_from_list(items: object) -> str | None: + if not isinstance(items, list): + return None + for item in items: + agent_id = _extract_agent_id_from_item(item) + if agent_id: + return agent_id + return None + + +def _extract_agent_id(payload: object) -> str | None: + default_keys = ("defaultId", "default_id", "defaultAgentId", "default_agent_id") + collection_keys = ("agents", "items", "list", "data") if isinstance(payload, list): - return _from_list(payload) + return _extract_agent_id_from_list(payload) if not isinstance(payload, dict): return None - for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"): - raw = payload.get(key) - if isinstance(raw, str) and raw.strip(): - return raw.strip() - for key in ("agents", "items", "list", "data"): - agent_id = _from_list(payload.get(key)) + for key in default_keys: + agent_id = _clean_str(payload.get(key)) + if agent_id: + return agent_id + for key in collection_keys: + agent_id = _extract_agent_id_from_list(payload.get(key)) if agent_id: return agent_id return None @@ -523,42 +571,44 @@ async def _patch_gateway_agent_list( await openclaw_call("config.patch", params, config=config) -async def patch_gateway_agent_heartbeats( # noqa: C901 - gateway: Gateway, - *, - entries: list[tuple[str, str, dict[str, Any]]], -) -> None: - """Patch multiple agent heartbeat configs in a single gateway config.patch call. - - Each entry is (agent_id, workspace_path, heartbeat_dict). - """ - if not gateway.url: - msg = "Gateway url is required" - raise OpenClawGatewayError(msg) - config = GatewayClientConfig(url=gateway.url, token=gateway.token) +async def _gateway_config_agent_list( + config: GatewayClientConfig, +) -> tuple[str | None, list[object]]: cfg = await openclaw_call("config.get", config=config) if not isinstance(cfg, dict): msg = "config.get returned invalid payload" raise OpenClawGatewayError(msg) - base_hash = cfg.get("hash") + data = cfg.get("config") or cfg.get("parsed") or {} if not isinstance(data, dict): msg = "config.get returned invalid config" raise OpenClawGatewayError(msg) + agents_section = data.get("agents") or {} - lst = agents_section.get("list") or [] - if not isinstance(lst, list): + agents_list = agents_section.get("list") or [] + if not isinstance(agents_list, list): msg = "config agents.list is not a list" raise OpenClawGatewayError(msg) + return cfg.get("hash"), agents_list - entry_by_id: dict[str, tuple[str, dict[str, Any]]] = { + +def _heartbeat_entry_map( + entries: list[tuple[str, str, dict[str, Any]]], +) -> dict[str, tuple[str, dict[str, Any]]]: + return { agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries } + +def _updated_agent_list( + raw_list: list[object], + entry_by_id: dict[str, tuple[str, dict[str, Any]]], +) -> list[object]: updated_ids: set[str] = set() - new_list: list[dict[str, Any]] = [] - for raw_entry in lst: + new_list: list[object] = [] + + for raw_entry in raw_list: if not isinstance(raw_entry, dict): new_list.append(raw_entry) continue @@ -566,6 +616,7 @@ async def patch_gateway_agent_heartbeats( # noqa: C901 if not isinstance(agent_id, str) or agent_id not in entry_by_id: new_list.append(raw_entry) continue + workspace_path, heartbeat = entry_by_id[agent_id] new_entry = dict(raw_entry) new_entry["workspace"] = workspace_path @@ -580,6 +631,26 @@ async def patch_gateway_agent_heartbeats( # noqa: C901 {"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat}, ) + return new_list + + +async def patch_gateway_agent_heartbeats( + gateway: Gateway, + *, + entries: list[tuple[str, str, dict[str, Any]]], +) -> None: + """Patch multiple agent heartbeat configs in a single gateway config.patch call. + + Each entry is (agent_id, workspace_path, heartbeat_dict). + """ + if not gateway.url: + msg = "Gateway url is required" + raise OpenClawGatewayError(msg) + config = GatewayClientConfig(url=gateway.url, token=gateway.token) + base_hash, raw_list = await _gateway_config_agent_list(config) + entry_by_id = _heartbeat_entry_map(entries) + new_list = _updated_agent_list(raw_list, entry_by_id) + patch = {"agents": {"list": new_list}} params = {"raw": json.dumps(patch)} if base_hash: @@ -656,18 +727,52 @@ async def _get_gateway_agent_entry( return None -async def provision_agent( # noqa: C901, PLR0912, PLR0913 - agent: Agent, - board: Board, - gateway: Gateway, - auth_token: str, - user: User | None, +def _should_include_bootstrap( *, - action: str = "provision", - force_bootstrap: bool = False, - reset_session: bool = False, + action: str, + force_bootstrap: bool, + existing_files: dict[str, dict[str, Any]], +) -> bool: + if action != "update" or force_bootstrap: + return True + if not existing_files: + return False + entry = existing_files.get("BOOTSTRAP.md") + return not (entry and entry.get("missing") is True) + + +async def _set_agent_files( + *, + agent_id: str, + rendered: dict[str, str], + existing_files: dict[str, dict[str, Any]], + client_config: GatewayClientConfig, +) -> None: + for name, content in rendered.items(): + if content == "": + continue + if name in PRESERVE_AGENT_EDITABLE_FILES: + entry = existing_files.get(name) + if entry and entry.get("missing") is not True: + continue + try: + await openclaw_call( + "agents.files.set", + {"agentId": agent_id, "name": name, "content": content}, + config=client_config, + ) + except OpenClawGatewayError as exc: + if "unsupported file" in str(exc).lower(): + continue + raise + + +async def provision_agent( + agent: Agent, + request: AgentProvisionRequest, ) -> None: """Provision or update a regular board agent workspace.""" + gateway = request.gateway if not gateway.url: return if not gateway.workspace_root: @@ -682,18 +787,21 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913 heartbeat = _heartbeat_config(agent) await _patch_gateway_agent_list(agent_id, workspace_path, heartbeat, client_config) - context = _build_context(agent, board, gateway, auth_token, user) + context = _build_context( + agent, + request.board, + gateway, + request.auth_token, + request.user, + ) supported = set(await _supported_gateway_files(client_config)) supported.update({"USER.md", "SELF.md", "AUTONOMY.md"}) existing_files = await _gateway_agent_files_index(agent_id, client_config) - include_bootstrap = True - if action == "update" and not force_bootstrap: - if not existing_files: - include_bootstrap = False - else: - entry = existing_files.get("BOOTSTRAP.md") - if entry and entry.get("missing") is True: - include_bootstrap = False + include_bootstrap = _should_include_bootstrap( + action=request.options.action, + force_bootstrap=request.options.force_bootstrap, + existing_files=existing_files, + ) rendered = _render_agent_files( context, @@ -710,41 +818,22 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913 with suppress(OSError): # Local workspace may not be writable/available; fall back to gateway API. _ensure_workspace_file(workspace_path, name, content, overwrite=False) - for name, content in rendered.items(): - if content == "": - continue - if name in PRESERVE_AGENT_EDITABLE_FILES: - # Never overwrite; only provision if missing. - entry = existing_files.get(name) - if entry and entry.get("missing") is not True: - continue - try: - await openclaw_call( - "agents.files.set", - {"agentId": agent_id, "name": name, "content": content}, - config=client_config, - ) - except OpenClawGatewayError as exc: - # Gateways may restrict file names. Skip unsupported files rather than - # failing provisioning for the entire agent. - if "unsupported file" in str(exc).lower(): - continue - raise - if reset_session: + await _set_agent_files( + agent_id=agent_id, + rendered=rendered, + existing_files=existing_files, + client_config=client_config, + ) + if request.options.reset_session: await _reset_session(session_key, client_config) -async def provision_main_agent( # noqa: C901, PLR0912, PLR0913 +async def provision_main_agent( agent: Agent, - gateway: Gateway, - auth_token: str, - user: User | None, - *, - action: str = "provision", - force_bootstrap: bool = False, - reset_session: bool = False, + request: MainAgentProvisionRequest, ) -> None: """Provision or update the gateway main agent workspace.""" + gateway = request.gateway if not gateway.url: return if not gateway.main_session_key: @@ -763,18 +852,15 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913 msg = "Unable to resolve gateway main agent id" raise OpenClawGatewayError(msg) - context = _build_main_context(agent, gateway, auth_token, user) + context = _build_main_context(agent, gateway, request.auth_token, request.user) supported = set(await _supported_gateway_files(client_config)) supported.update({"USER.md", "SELF.md", "AUTONOMY.md"}) existing_files = await _gateway_agent_files_index(agent_id, client_config) - include_bootstrap = action != "update" or force_bootstrap - if action == "update" and not force_bootstrap: - if not existing_files: - include_bootstrap = False - else: - entry = existing_files.get("BOOTSTRAP.md") - if entry and entry.get("missing") is True: - include_bootstrap = False + include_bootstrap = _should_include_bootstrap( + action=request.options.action, + force_bootstrap=request.options.force_bootstrap, + existing_files=existing_files, + ) rendered = _render_agent_files( context, @@ -783,24 +869,13 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913 include_bootstrap=include_bootstrap, template_overrides=MAIN_TEMPLATE_MAP, ) - for name, content in rendered.items(): - if content == "": - continue - if name in PRESERVE_AGENT_EDITABLE_FILES: - entry = existing_files.get(name) - if entry and entry.get("missing") is not True: - continue - try: - await openclaw_call( - "agents.files.set", - {"agentId": agent_id, "name": name, "content": content}, - config=client_config, - ) - except OpenClawGatewayError as exc: - if "unsupported file" in str(exc).lower(): - continue - raise - if reset_session: + await _set_agent_files( + agent_id=agent_id, + rendered=rendered, + existing_files=existing_files, + client_config=client_config, + ) + if request.options.reset_session: await _reset_session(gateway.main_session_key, client_config) diff --git a/backend/app/services/board_leads.py b/backend/app/services/board_leads.py index 9f2028a6..ed5289a2 100644 --- a/backend/app/services/board_leads.py +++ b/backend/app/services/board_leads.py @@ -2,6 +2,7 @@ from __future__ import annotations +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from sqlmodel import col, select @@ -15,7 +16,12 @@ from app.integrations.openclaw_gateway import ( send_message, ) from app.models.agents import Agent -from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent +from app.services.agent_provisioning import ( + DEFAULT_HEARTBEAT_CONFIG, + AgentProvisionRequest, + ProvisionOptions, + provision_agent, +) if TYPE_CHECKING: from sqlmodel.ext.asyncio.session import AsyncSession @@ -35,18 +41,34 @@ def lead_agent_name(_: Board) -> str: return "Lead Agent" -async def ensure_board_lead_agent( # noqa: PLR0913 +@dataclass(frozen=True, slots=True) +class LeadAgentOptions: + """Optional overrides for board-lead provisioning behavior.""" + + agent_name: str | None = None + identity_profile: dict[str, str] | None = None + action: str = "provision" + + +@dataclass(frozen=True, slots=True) +class LeadAgentRequest: + """Inputs required to ensure or provision a board lead agent.""" + + board: Board + gateway: Gateway + config: GatewayClientConfig + user: User | None + options: LeadAgentOptions = field(default_factory=LeadAgentOptions) + + +async def ensure_board_lead_agent( session: AsyncSession, *, - board: Board, - gateway: Gateway, - config: GatewayClientConfig, - user: User | None, - agent_name: str | None = None, - identity_profile: dict[str, str] | None = None, - action: str = "provision", + request: LeadAgentRequest, ) -> tuple[Agent, bool]: """Ensure a board has a lead agent; return `(agent, created)`.""" + board = request.board + config_options = request.options existing = ( await session.exec( select(Agent) @@ -55,7 +77,7 @@ async def ensure_board_lead_agent( # noqa: PLR0913 ) ).first() if existing: - desired_name = agent_name or lead_agent_name(board) + desired_name = config_options.agent_name or lead_agent_name(board) changed = False if existing.name != desired_name: existing.name = desired_name @@ -76,17 +98,17 @@ async def ensure_board_lead_agent( # noqa: PLR0913 "communication_style": "direct, concise, practical", "emoji": ":gear:", } - if identity_profile: + if config_options.identity_profile: merged_identity_profile.update( { key: value.strip() - for key, value in identity_profile.items() + for key, value in config_options.identity_profile.items() if value.strip() }, ) agent = Agent( - name=agent_name or lead_agent_name(board), + name=config_options.agent_name or lead_agent_name(board), status="provisioning", board_id=board.id, is_board_lead=True, @@ -94,7 +116,7 @@ async def ensure_board_lead_agent( # noqa: PLR0913 identity_profile=merged_identity_profile, openclaw_session_id=lead_session_key(board), provision_requested_at=utcnow(), - provision_action=action, + provision_action=config_options.action, ) raw_token = generate_agent_token() agent.agent_token_hash = hash_agent_token(raw_token) @@ -103,11 +125,20 @@ async def ensure_board_lead_agent( # noqa: PLR0913 await session.refresh(agent) try: - await provision_agent(agent, board, gateway, raw_token, user, action=action) + await provision_agent( + agent, + AgentProvisionRequest( + board=board, + gateway=request.gateway, + auth_token=raw_token, + user=request.user, + options=ProvisionOptions(action=config_options.action), + ), + ) if agent.openclaw_session_id: await ensure_session( agent.openclaw_session_id, - config=config, + config=request.config, label=agent.name, ) await send_message( @@ -118,7 +149,7 @@ async def ensure_board_lead_agent( # noqa: PLR0913 "then delete it. Begin heartbeats after startup." ), session_key=agent.openclaw_session_id, - config=config, + config=request.config, deliver=True, ) except OpenClawGatewayError: diff --git a/backend/app/services/souls_directory.py b/backend/app/services/souls_directory.py index 7e9f9e03..d144a482 100644 --- a/backend/app/services/souls_directory.py +++ b/backend/app/services/souls_directory.py @@ -2,9 +2,10 @@ from __future__ import annotations +import re import time -import xml.etree.ElementTree as ET from dataclasses import dataclass +from html import unescape from typing import Final import httpx @@ -14,6 +15,10 @@ SOULS_DIRECTORY_SITEMAP_URL: Final[str] = f"{SOULS_DIRECTORY_BASE_URL}/sitemap.x _SITEMAP_TTL_SECONDS: Final[int] = 60 * 60 _SOUL_URL_MIN_PARTS: Final[int] = 6 +_LOC_PATTERN: Final[re.Pattern[str]] = re.compile( + r"<(?:[A-Za-z0-9_]+:)?loc>(.*?)", + flags=re.IGNORECASE | re.DOTALL, +) @dataclass(frozen=True, slots=True) @@ -36,17 +41,10 @@ class SoulRef: def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]: """Parse sitemap XML and extract valid souls.directory handle/slug refs.""" - try: - # Souls sitemap is fetched from a known trusted host in this service flow. - root = ET.fromstring(sitemap_xml) # noqa: S314 - except ET.ParseError: - return [] - - # Handle both namespaced and non-namespaced sitemap XML. + # Extract values without XML entity expansion. urls = [ - loc.text.strip() - for loc in root.iter() - if loc.tag.endswith("loc") and loc.text + unescape(match.group(1)).strip() + for match in _LOC_PATTERN.finditer(sitemap_xml) ] refs: list[SoulRef] = [] diff --git a/backend/app/services/template_sync.py b/backend/app/services/template_sync.py index 7bb6e9e4..ac9e3508 100644 --- a/backend/app/services/template_sync.py +++ b/backend/app/services/template_sync.py @@ -28,7 +28,13 @@ from app.models.boards import Board from app.models.gateways import Gateway from app.models.users import User from app.schemas.gateways import GatewayTemplatesSyncError, GatewayTemplatesSyncResult -from app.services.agent_provisioning import provision_agent, provision_main_agent +from app.services.agent_provisioning import ( + AgentProvisionRequest, + MainAgentProvisionRequest, + ProvisionOptions, + provision_agent, + provision_main_agent, +) _TOOLS_KV_RE = re.compile(r"^(?P[A-Z0-9_]+)=(?P.*)$") SESSION_KEY_PARTS_MIN = 2 @@ -480,13 +486,17 @@ async def _sync_one_agent( async def _do_provision() -> None: await provision_agent( agent, - board, - ctx.gateway, - auth_token, - ctx.options.user, - action="update", - force_bootstrap=ctx.options.force_bootstrap, - reset_session=ctx.options.reset_sessions, + AgentProvisionRequest( + board=board, + gateway=ctx.gateway, + auth_token=auth_token, + user=ctx.options.user, + options=ProvisionOptions( + action="update", + force_bootstrap=ctx.options.force_bootstrap, + reset_session=ctx.options.reset_sessions, + ), + ), ) await _with_gateway_retry(_do_provision, backoff=ctx.backoff) @@ -564,12 +574,16 @@ async def _sync_main_agent( async def _do_provision_main() -> None: await provision_main_agent( main_agent, - ctx.gateway, - token, - ctx.options.user, - action="update", - force_bootstrap=ctx.options.force_bootstrap, - reset_session=ctx.options.reset_sessions, + MainAgentProvisionRequest( + gateway=ctx.gateway, + auth_token=token, + user=ctx.options.user, + options=ProvisionOptions( + action="update", + force_bootstrap=ctx.options.force_bootstrap, + reset_session=ctx.options.reset_sessions, + ), + ), ) await _with_gateway_retry(_do_provision_main, backoff=ctx.backoff) diff --git a/backend/migrations/env.py b/backend/migrations/env.py index 1561876e..a5f7cd84 100644 --- a/backend/migrations/env.py +++ b/backend/migrations/env.py @@ -2,6 +2,7 @@ from __future__ import annotations +import importlib import sys from logging.config import fileConfig from pathlib import Path @@ -14,8 +15,8 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.append(str(PROJECT_ROOT)) -from app import models # noqa: E402,F401 -from app.core.config import settings # noqa: E402 +importlib.import_module("app.models") +settings = importlib.import_module("app.core.config").settings config = context.config configure_logger = config.attributes.get("configure_logger", True) diff --git a/backend/migrations/versions/658dca8f4a11_init.py b/backend/migrations/versions/658dca8f4a11_init.py index df5f21f6..c137c271 100644 --- a/backend/migrations/versions/658dca8f4a11_init.py +++ b/backend/migrations/versions/658dca8f4a11_init.py @@ -8,7 +8,6 @@ Create Date: 2026-02-09 00:41:55.760624 from __future__ import annotations -# ruff: noqa: INP001 import sqlalchemy as sa import sqlmodel from alembic import op @@ -20,8 +19,17 @@ branch_labels = None depends_on = None -def upgrade() -> None: # noqa: PLR0915 + + +def upgrade() -> None: """Create initial schema objects.""" + _upgrade_part_1() + _upgrade_part_2() + _upgrade_part_3() + _upgrade_part_4() + + +def _upgrade_part_1() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( "organizations", @@ -183,6 +191,9 @@ def upgrade() -> None: # noqa: PLR0915 op.f("ix_boards_organization_id"), "boards", ["organization_id"], unique=False, ) op.create_index(op.f("ix_boards_slug"), "boards", ["slug"], unique=False) + + +def _upgrade_part_2() -> None: op.create_table( "organization_invites", sa.Column("id", sa.Uuid(), nullable=False), @@ -366,6 +377,9 @@ def upgrade() -> None: # noqa: PLR0915 unique=False, ) op.create_index(op.f("ix_agents_status"), "agents", ["status"], unique=False) + + +def _upgrade_part_3() -> None: op.create_table( "board_memory", sa.Column("id", sa.Uuid(), nullable=False), @@ -532,6 +546,9 @@ def upgrade() -> None: # noqa: PLR0915 ) op.create_index(op.f("ix_tasks_priority"), "tasks", ["priority"], unique=False) op.create_index(op.f("ix_tasks_status"), "tasks", ["status"], unique=False) + + +def _upgrade_part_4() -> None: op.create_table( "activity_events", sa.Column("id", sa.Uuid(), nullable=False), @@ -686,8 +703,14 @@ def upgrade() -> None: # noqa: PLR0915 # ### end Alembic commands ### -def downgrade() -> None: # noqa: PLR0915 +def downgrade() -> None: """Drop initial schema objects.""" + _downgrade_part_1() + _downgrade_part_2() + _downgrade_part_3() + + +def _downgrade_part_1() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_index( op.f("ix_task_fingerprints_fingerprint_hash"), table_name="task_fingerprints", @@ -745,6 +768,9 @@ def downgrade() -> None: # noqa: PLR0915 op.drop_index(op.f("ix_board_memory_is_chat"), table_name="board_memory") op.drop_index(op.f("ix_board_memory_board_id"), table_name="board_memory") op.drop_table("board_memory") + + +def _downgrade_part_2() -> None: op.drop_index(op.f("ix_agents_status"), table_name="agents") op.drop_index(op.f("ix_agents_provision_confirm_token_hash"), table_name="agents") op.drop_index(op.f("ix_agents_provision_action"), table_name="agents") @@ -795,6 +821,9 @@ def downgrade() -> None: # noqa: PLR0915 op.drop_index(op.f("ix_boards_board_type"), table_name="boards") op.drop_index(op.f("ix_boards_board_group_id"), table_name="boards") op.drop_table("boards") + + +def _downgrade_part_3() -> None: op.drop_index( op.f("ix_board_group_memory_is_chat"), table_name="board_group_memory", ) diff --git a/backend/migrations/versions/__init__.py b/backend/migrations/versions/__init__.py new file mode 100644 index 00000000..dbb17290 --- /dev/null +++ b/backend/migrations/versions/__init__.py @@ -0,0 +1 @@ +"""Alembic migration version modules.""" diff --git a/backend/scripts/export_openapi.py b/backend/scripts/export_openapi.py index ea842b8f..aad62bc1 100644 --- a/backend/scripts/export_openapi.py +++ b/backend/scripts/export_openapi.py @@ -9,11 +9,11 @@ from pathlib import Path BACKEND_ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(BACKEND_ROOT)) -from app.main import app # noqa: E402 - def main() -> None: """Generate `openapi.json` from the FastAPI app definition.""" + from app.main import app + # Importing the FastAPI app does not run lifespan hooks, # so this does not require a DB. out_path = BACKEND_ROOT / "openapi.json" diff --git a/backend/scripts/seed_demo.py b/backend/scripts/seed_demo.py index 2866e337..b3b922cc 100644 --- a/backend/scripts/seed_demo.py +++ b/backend/scripts/seed_demo.py @@ -10,15 +10,15 @@ from uuid import uuid4 BACKEND_ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(BACKEND_ROOT)) -from app.db.session import async_session_maker, init_db # noqa: E402 -from app.models.agents import Agent # noqa: E402 -from app.models.boards import Board # noqa: E402 -from app.models.gateways import Gateway # noqa: E402 -from app.models.users import User # noqa: E402 - async def run() -> None: """Populate the local database with a demo gateway, board, user, and agent.""" + from app.db.session import async_session_maker, init_db + from app.models.agents import Agent + from app.models.boards import Board + from app.models.gateways import Gateway + from app.models.users import User + await init_db() async with async_session_maker() as session: demo_workspace_root = BACKEND_ROOT / ".tmp" / "openclaw-demo" diff --git a/backend/scripts/sync_gateway_templates.py b/backend/scripts/sync_gateway_templates.py index a6deb0e8..fe07bff1 100644 --- a/backend/scripts/sync_gateway_templates.py +++ b/backend/scripts/sync_gateway_templates.py @@ -1,4 +1,3 @@ -# ruff: noqa: INP001 """CLI script to sync template files into gateway agent workspaces.""" from __future__ import annotations @@ -12,10 +11,6 @@ from uuid import UUID BACKEND_ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(BACKEND_ROOT)) -from app.db.session import async_session_maker # noqa: E402 -from app.models.gateways import Gateway # noqa: E402 -from app.services.template_sync import sync_gateway_templates # noqa: E402 - def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( @@ -59,6 +54,13 @@ def _parse_args() -> argparse.Namespace: async def _run() -> int: + from app.db.session import async_session_maker + from app.models.gateways import Gateway + from app.services.template_sync import ( + GatewayTemplateSyncOptions, + sync_gateway_templates, + ) + args = _parse_args() gateway_id = UUID(args.gateway_id) board_id = UUID(args.board_id) if args.board_id else None @@ -72,12 +74,14 @@ async def _run() -> int: result = await sync_gateway_templates( session, gateway, - user=None, - include_main=bool(args.include_main), - reset_sessions=bool(args.reset_sessions), - rotate_tokens=bool(args.rotate_tokens), - force_bootstrap=bool(args.force_bootstrap), - board_id=board_id, + options=GatewayTemplateSyncOptions( + user=None, + include_main=bool(args.include_main), + reset_sessions=bool(args.reset_sessions), + rotate_tokens=bool(args.rotate_tokens), + force_bootstrap=bool(args.force_bootstrap), + board_id=board_id, + ), ) sys.stdout.write(f"gateway_id={result.gateway_id}\n")