diff --git a/backend/app/services/agent_provisioning.py b/backend/app/services/agent_provisioning.py index 1a54a062..faab1ced 100644 --- a/backend/app/services/agent_provisioning.py +++ b/backend/app/services/agent_provisioning.py @@ -86,6 +86,49 @@ def _slugify(value: str) -> str: return slug or uuid4().hex +def _agent_id_from_session_key(session_key: str | None) -> str | None: + value = (session_key or "").strip() + if not value: + return None + if not value.startswith("agent:"): + return None + parts = value.split(":") + if len(parts) < 2: + return None + agent_id = parts[1].strip() + return agent_id or None + + +def _extract_agent_id(payload: object) -> str | None: + def _from_list(items: object) -> str | None: + if not isinstance(items, list): + return None + for item in items: + if isinstance(item, str) and item.strip(): + return item.strip() + if not isinstance(item, dict): + continue + for key in ("id", "agentId", "agent_id"): + raw = item.get(key) + if isinstance(raw, str) and raw.strip(): + return raw.strip() + return None + + if isinstance(payload, list): + return _from_list(payload) + if not isinstance(payload, dict): + return None + for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"): + raw = payload.get(key) + if isinstance(raw, str) and raw.strip(): + return raw.strip() + for key in ("agents", "items", "list", "data"): + agent_id = _from_list(payload.get(key)) + if agent_id: + return agent_id + return None + + def _agent_key(agent: Agent) -> str: session_key = agent.openclaw_session_id or "" if session_key.startswith("agent:"): @@ -383,24 +426,18 @@ def _render_agent_files( async def _gateway_default_agent_id( config: GatewayClientConfig, + *, + fallback_session_key: str | None = None, ) -> str | None: try: payload = await openclaw_call("agents.list", config=config) except OpenClawGatewayError: - return None - if not isinstance(payload, dict): - return None - default_id = payload.get("defaultId") or payload.get("default_id") - if isinstance(default_id, str) and default_id: - return default_id - agents = payload.get("agents") or [] - if isinstance(agents, list) and agents: - first = agents[0] - if isinstance(first, dict): - agent_id = first.get("id") - if isinstance(agent_id, str) and agent_id: - return agent_id - return None + return _agent_id_from_session_key(fallback_session_key) + + agent_id = _extract_agent_id(payload) + if agent_id: + return agent_id + return _agent_id_from_session_key(fallback_session_key) async def _patch_gateway_agent_list( @@ -585,7 +622,10 @@ async def provision_main_agent( client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent") - agent_id = await _gateway_default_agent_id(client_config) + agent_id = await _gateway_default_agent_id( + client_config, + fallback_session_key=gateway.main_session_key, + ) if not agent_id: raise OpenClawGatewayError("Unable to resolve gateway main agent id") diff --git a/backend/app/services/template_sync.py b/backend/app/services/template_sync.py index 825a9f65..64fe2b52 100644 --- a/backend/app/services/template_sync.py +++ b/backend/app/services/template_sync.py @@ -1,6 +1,9 @@ from __future__ import annotations +import asyncio import re +from collections.abc import Awaitable, Callable +from typing import TypeVar from uuid import UUID, uuid4 from sqlmodel import col, select @@ -19,12 +22,94 @@ from app.services.agent_provisioning import provision_agent, provision_main_agen _TOOLS_KV_RE = re.compile(r"^(?P[A-Z0-9_]+)=(?P.*)$") +T = TypeVar("T") + def _slugify(value: str) -> str: slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-") return slug or uuid4().hex +def _is_transient_gateway_error(exc: Exception) -> bool: + if not isinstance(exc, OpenClawGatewayError): + return False + message = str(exc).lower() + if not message: + return False + if "unsupported file" in message: + return False + if "received 1012" in message or "service restart" in message: + return True + if "http 503" in message or ("503" in message and "websocket" in message): + return True + if "temporar" in message: + return True + if "timeout" in message or "timed out" in message: + return True + if "connection closed" in message or "connection reset" in message: + return True + return False + + +async def _with_gateway_retry( + fn: Callable[[], Awaitable[T]], + *, + attempts: int = 3, + base_delay_s: float = 0.75, +) -> T: + for attempt in range(attempts): + try: + return await fn() + except Exception as exc: + if attempt >= attempts - 1 or not _is_transient_gateway_error(exc): + raise + await asyncio.sleep(base_delay_s * (2**attempt)) + raise AssertionError("unreachable") + + +def _agent_id_from_session_key(session_key: str | None) -> str | None: + value = (session_key or "").strip() + if not value: + return None + if not value.startswith("agent:"): + return None + parts = value.split(":") + if len(parts) < 2: + return None + agent_id = parts[1].strip() + return agent_id or None + + +def _extract_agent_id(payload: object) -> str | None: + def _from_list(items: object) -> str | None: + if not isinstance(items, list): + return None + for item in items: + if isinstance(item, str) and item.strip(): + return item.strip() + if not isinstance(item, dict): + continue + for key in ("id", "agentId", "agent_id"): + raw = item.get(key) + if isinstance(raw, str) and raw.strip(): + return raw.strip() + return None + + if isinstance(payload, list): + return _from_list(payload) + if not isinstance(payload, dict): + return None + for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"): + raw = payload.get(key) + if isinstance(raw, str) and raw.strip(): + return raw.strip() + for key in ("agents", "items", "list", "data"): + agent_id = _from_list(payload.get(key)) + if agent_id: + return agent_id + return None + + def _gateway_agent_id(agent: Agent) -> str: session_key = agent.openclaw_session_id or "" if session_key.startswith("agent:"): @@ -94,24 +179,34 @@ async def _get_existing_auth_token( return token or None -async def _gateway_default_agent_id(config: GatewayClientConfig) -> str | None: - try: - payload = await openclaw_call("agents.list", config=config) - except OpenClawGatewayError: - return None - if not isinstance(payload, dict): - return None - default_id = payload.get("defaultId") or payload.get("default_id") - if isinstance(default_id, str) and default_id: - return default_id - agents = payload.get("agents") or [] - if isinstance(agents, list) and agents: - first = agents[0] - if isinstance(first, dict): - agent_id = first.get("id") - if isinstance(agent_id, str) and agent_id: +async def _gateway_default_agent_id( + config: GatewayClientConfig, + *, + fallback_session_key: str | None = None, +) -> str | None: + last_error: OpenClawGatewayError | None = None + # Gateways may reject WS connects transiently under load (HTTP 503). + for attempt in range(3): + try: + payload = await openclaw_call("agents.list", config=config) + agent_id = _extract_agent_id(payload) + if agent_id: return agent_id - return None + break + except OpenClawGatewayError as exc: + last_error = exc + message = str(exc).lower() + if ( + "503" not in message + and "temporar" not in message + and "rejected" not in message + and "timeout" not in message + ): + break + await asyncio.sleep(0.5 * (2**attempt)) + + _ = last_error + return _agent_id_from_session_key(fallback_session_key) async def sync_gateway_templates( @@ -226,16 +321,19 @@ async def sync_gateway_templates( ) try: - await provision_agent( - agent, - board, - gateway, - auth_token, - user, - action="update", - force_bootstrap=force_bootstrap, - reset_session=reset_sessions, - ) + async def _do_provision() -> None: + await provision_agent( + agent, + board, + gateway, + auth_token, + user, + action="update", + force_bootstrap=force_bootstrap, + reset_session=reset_sessions, + ) + + await _with_gateway_retry(_do_provision) result.agents_updated += 1 except Exception as exc: # pragma: no cover - gateway/network dependent result.agents_skipped += 1 @@ -262,7 +360,10 @@ async def sync_gateway_templates( ) return result - main_gateway_agent_id = await _gateway_default_agent_id(client_config) + main_gateway_agent_id = await _gateway_default_agent_id( + client_config, + fallback_session_key=gateway.main_session_key, + ) if not main_gateway_agent_id: result.errors.append( GatewayTemplatesSyncError( @@ -277,25 +378,57 @@ async def sync_gateway_templates( agent_gateway_id=main_gateway_agent_id, config=client_config ) if not main_token: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.", + if rotate_tokens: + raw_token = generate_agent_token() + main_agent.agent_token_hash = hash_agent_token(raw_token) + main_agent.updated_at = utcnow() + session.add(main_agent) + await session.commit() + await session.refresh(main_agent) + main_token = raw_token + else: + result.errors.append( + GatewayTemplatesSyncError( + agent_id=main_agent.id, + agent_name=main_agent.name, + message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.", + ) + ) + return result + + if main_agent.agent_token_hash and not verify_agent_token( + main_token, main_agent.agent_token_hash + ): + if rotate_tokens: + raw_token = generate_agent_token() + main_agent.agent_token_hash = hash_agent_token(raw_token) + main_agent.updated_at = utcnow() + session.add(main_agent) + await session.commit() + await session.refresh(main_agent) + main_token = raw_token + else: + result.errors.append( + GatewayTemplatesSyncError( + agent_id=main_agent.id, + agent_name=main_agent.name, + message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).", + ) ) - ) - return result try: - await provision_main_agent( - main_agent, - gateway, - main_token, - user, - action="update", - force_bootstrap=force_bootstrap, - reset_session=reset_sessions, - ) + async def _do_provision_main() -> None: + await provision_main_agent( + main_agent, + gateway, + main_token, + user, + action="update", + force_bootstrap=force_bootstrap, + reset_session=reset_sessions, + ) + + await _with_gateway_retry(_do_provision_main) result.main_updated = True except Exception as exc: # pragma: no cover - gateway/network dependent result.errors.append( diff --git a/backend/scripts/sync_gateway_templates.py b/backend/scripts/sync_gateway_templates.py index e58f6097..72d44ab7 100644 --- a/backend/scripts/sync_gateway_templates.py +++ b/backend/scripts/sync_gateway_templates.py @@ -2,11 +2,16 @@ from __future__ import annotations import argparse import asyncio +import sys +from pathlib import Path from uuid import UUID -from app.db.session import async_session_maker -from app.models.gateways import Gateway -from app.services.template_sync import sync_gateway_templates +BACKEND_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(BACKEND_ROOT)) + +from app.db.session import async_session_maker # noqa: E402 +from app.models.gateways import Gateway # noqa: E402 +from app.services.template_sync import sync_gateway_templates # noqa: E402 def _parse_args() -> argparse.Namespace: