feat(agent): enhance agent ID extraction and fallback handling in gateway integration

This commit is contained in:
Abhimanyu Saharan
2026-02-07 16:21:21 +05:30
parent 7ef1f3e2f8
commit 0816fb6cd3
3 changed files with 240 additions and 62 deletions

View File

@@ -86,6 +86,49 @@ def _slugify(value: str) -> str:
return slug or uuid4().hex return slug or uuid4().hex
def _agent_id_from_session_key(session_key: str | None) -> str | None:
value = (session_key or "").strip()
if not value:
return None
if not value.startswith("agent:"):
return None
parts = value.split(":")
if len(parts) < 2:
return None
agent_id = parts[1].strip()
return agent_id or None
def _extract_agent_id(payload: object) -> str | None:
def _from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
for item in items:
if isinstance(item, str) and item.strip():
return item.strip()
if not isinstance(item, dict):
continue
for key in ("id", "agentId", "agent_id"):
raw = item.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
return None
if isinstance(payload, list):
return _from_list(payload)
if not isinstance(payload, dict):
return None
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
raw = payload.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
for key in ("agents", "items", "list", "data"):
agent_id = _from_list(payload.get(key))
if agent_id:
return agent_id
return None
def _agent_key(agent: Agent) -> str: def _agent_key(agent: Agent) -> str:
session_key = agent.openclaw_session_id or "" session_key = agent.openclaw_session_id or ""
if session_key.startswith("agent:"): if session_key.startswith("agent:"):
@@ -383,24 +426,18 @@ def _render_agent_files(
async def _gateway_default_agent_id( async def _gateway_default_agent_id(
config: GatewayClientConfig, config: GatewayClientConfig,
*,
fallback_session_key: str | None = None,
) -> str | None: ) -> str | None:
try: try:
payload = await openclaw_call("agents.list", config=config) payload = await openclaw_call("agents.list", config=config)
except OpenClawGatewayError: except OpenClawGatewayError:
return None return _agent_id_from_session_key(fallback_session_key)
if not isinstance(payload, dict):
return None agent_id = _extract_agent_id(payload)
default_id = payload.get("defaultId") or payload.get("default_id") if agent_id:
if isinstance(default_id, str) and default_id: return agent_id
return default_id return _agent_id_from_session_key(fallback_session_key)
agents = payload.get("agents") or []
if isinstance(agents, list) and agents:
first = agents[0]
if isinstance(first, dict):
agent_id = first.get("id")
if isinstance(agent_id, str) and agent_id:
return agent_id
return None
async def _patch_gateway_agent_list( async def _patch_gateway_agent_list(
@@ -585,7 +622,10 @@ async def provision_main_agent(
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent") await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
agent_id = await _gateway_default_agent_id(client_config) agent_id = await _gateway_default_agent_id(
client_config,
fallback_session_key=gateway.main_session_key,
)
if not agent_id: if not agent_id:
raise OpenClawGatewayError("Unable to resolve gateway main agent id") raise OpenClawGatewayError("Unable to resolve gateway main agent id")

View File

@@ -1,6 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import re import re
from collections.abc import Awaitable, Callable
from typing import TypeVar
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import col, select from sqlmodel import col, select
@@ -19,12 +22,94 @@ from app.services.agent_provisioning import provision_agent, provision_main_agen
_TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$") _TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$")
T = TypeVar("T")
def _slugify(value: str) -> str: def _slugify(value: str) -> str:
slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-") slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
return slug or uuid4().hex return slug or uuid4().hex
def _is_transient_gateway_error(exc: Exception) -> bool:
if not isinstance(exc, OpenClawGatewayError):
return False
message = str(exc).lower()
if not message:
return False
if "unsupported file" in message:
return False
if "received 1012" in message or "service restart" in message:
return True
if "http 503" in message or ("503" in message and "websocket" in message):
return True
if "temporar" in message:
return True
if "timeout" in message or "timed out" in message:
return True
if "connection closed" in message or "connection reset" in message:
return True
return False
async def _with_gateway_retry(
fn: Callable[[], Awaitable[T]],
*,
attempts: int = 3,
base_delay_s: float = 0.75,
) -> T:
for attempt in range(attempts):
try:
return await fn()
except Exception as exc:
if attempt >= attempts - 1 or not _is_transient_gateway_error(exc):
raise
await asyncio.sleep(base_delay_s * (2**attempt))
raise AssertionError("unreachable")
def _agent_id_from_session_key(session_key: str | None) -> str | None:
value = (session_key or "").strip()
if not value:
return None
if not value.startswith("agent:"):
return None
parts = value.split(":")
if len(parts) < 2:
return None
agent_id = parts[1].strip()
return agent_id or None
def _extract_agent_id(payload: object) -> str | None:
def _from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
for item in items:
if isinstance(item, str) and item.strip():
return item.strip()
if not isinstance(item, dict):
continue
for key in ("id", "agentId", "agent_id"):
raw = item.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
return None
if isinstance(payload, list):
return _from_list(payload)
if not isinstance(payload, dict):
return None
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
raw = payload.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
for key in ("agents", "items", "list", "data"):
agent_id = _from_list(payload.get(key))
if agent_id:
return agent_id
return None
def _gateway_agent_id(agent: Agent) -> str: def _gateway_agent_id(agent: Agent) -> str:
session_key = agent.openclaw_session_id or "" session_key = agent.openclaw_session_id or ""
if session_key.startswith("agent:"): if session_key.startswith("agent:"):
@@ -94,24 +179,34 @@ async def _get_existing_auth_token(
return token or None return token or None
async def _gateway_default_agent_id(config: GatewayClientConfig) -> str | None: async def _gateway_default_agent_id(
try: config: GatewayClientConfig,
payload = await openclaw_call("agents.list", config=config) *,
except OpenClawGatewayError: fallback_session_key: str | None = None,
return None ) -> str | None:
if not isinstance(payload, dict): last_error: OpenClawGatewayError | None = None
return None # Gateways may reject WS connects transiently under load (HTTP 503).
default_id = payload.get("defaultId") or payload.get("default_id") for attempt in range(3):
if isinstance(default_id, str) and default_id: try:
return default_id payload = await openclaw_call("agents.list", config=config)
agents = payload.get("agents") or [] agent_id = _extract_agent_id(payload)
if isinstance(agents, list) and agents: if agent_id:
first = agents[0]
if isinstance(first, dict):
agent_id = first.get("id")
if isinstance(agent_id, str) and agent_id:
return agent_id return agent_id
return None break
except OpenClawGatewayError as exc:
last_error = exc
message = str(exc).lower()
if (
"503" not in message
and "temporar" not in message
and "rejected" not in message
and "timeout" not in message
):
break
await asyncio.sleep(0.5 * (2**attempt))
_ = last_error
return _agent_id_from_session_key(fallback_session_key)
async def sync_gateway_templates( async def sync_gateway_templates(
@@ -226,16 +321,19 @@ async def sync_gateway_templates(
) )
try: try:
await provision_agent( async def _do_provision() -> None:
agent, await provision_agent(
board, agent,
gateway, board,
auth_token, gateway,
user, auth_token,
action="update", user,
force_bootstrap=force_bootstrap, action="update",
reset_session=reset_sessions, force_bootstrap=force_bootstrap,
) reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision)
result.agents_updated += 1 result.agents_updated += 1
except Exception as exc: # pragma: no cover - gateway/network dependent except Exception as exc: # pragma: no cover - gateway/network dependent
result.agents_skipped += 1 result.agents_skipped += 1
@@ -262,7 +360,10 @@ async def sync_gateway_templates(
) )
return result return result
main_gateway_agent_id = await _gateway_default_agent_id(client_config) main_gateway_agent_id = await _gateway_default_agent_id(
client_config,
fallback_session_key=gateway.main_session_key,
)
if not main_gateway_agent_id: if not main_gateway_agent_id:
result.errors.append( result.errors.append(
GatewayTemplatesSyncError( GatewayTemplatesSyncError(
@@ -277,25 +378,57 @@ async def sync_gateway_templates(
agent_gateway_id=main_gateway_agent_id, config=client_config agent_gateway_id=main_gateway_agent_id, config=client_config
) )
if not main_token: if not main_token:
result.errors.append( if rotate_tokens:
GatewayTemplatesSyncError( raw_token = generate_agent_token()
agent_id=main_agent.id, main_agent.agent_token_hash = hash_agent_token(raw_token)
agent_name=main_agent.name, main_agent.updated_at = utcnow()
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.", session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
)
)
return result
if main_agent.agent_token_hash and not verify_agent_token(
main_token, main_agent.agent_token_hash
):
if rotate_tokens:
raw_token = generate_agent_token()
main_agent.agent_token_hash = hash_agent_token(raw_token)
main_agent.updated_at = utcnow()
session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).",
)
) )
)
return result
try: try:
await provision_main_agent( async def _do_provision_main() -> None:
main_agent, await provision_main_agent(
gateway, main_agent,
main_token, gateway,
user, main_token,
action="update", user,
force_bootstrap=force_bootstrap, action="update",
reset_session=reset_sessions, force_bootstrap=force_bootstrap,
) reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision_main)
result.main_updated = True result.main_updated = True
except Exception as exc: # pragma: no cover - gateway/network dependent except Exception as exc: # pragma: no cover - gateway/network dependent
result.errors.append( result.errors.append(

View File

@@ -2,11 +2,16 @@ from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import sys
from pathlib import Path
from uuid import UUID from uuid import UUID
from app.db.session import async_session_maker BACKEND_ROOT = Path(__file__).resolve().parents[1]
from app.models.gateways import Gateway sys.path.insert(0, str(BACKEND_ROOT))
from app.services.template_sync import sync_gateway_templates
from app.db.session import async_session_maker # noqa: E402
from app.models.gateways import Gateway # noqa: E402
from app.services.template_sync import sync_gateway_templates # noqa: E402
def _parse_args() -> argparse.Namespace: def _parse_args() -> argparse.Namespace: