feat(agent): enhance agent ID extraction and fallback handling in gateway integration

This commit is contained in:
Abhimanyu Saharan
2026-02-07 16:21:21 +05:30
parent 7ef1f3e2f8
commit 0816fb6cd3
3 changed files with 240 additions and 62 deletions

View File

@@ -86,6 +86,49 @@ def _slugify(value: str) -> str:
return slug or uuid4().hex
def _agent_id_from_session_key(session_key: str | None) -> str | None:
    """Return the agent id embedded in an ``agent:<id>[:...]`` session key.

    Args:
        session_key: Session key to inspect; ``None`` is treated as empty.

    Returns:
        The whitespace-stripped second ``:``-separated segment, or ``None``
        when the key is empty, does not start with ``agent:``, or that
        segment is blank.
    """
    value = (session_key or "").strip()
    # An empty value fails this check too, so no separate emptiness test.
    if not value.startswith("agent:"):
        return None
    # The prefix check guarantees a ":" exists, so index 1 is always present.
    agent_id = value.split(":", 2)[1].strip()
    return agent_id or None
def _extract_agent_id(payload: object) -> str | None:
    """Pull the first usable agent id out of an ``agents.list``-style payload.

    A list payload is scanned directly. A dict payload is checked first for a
    default-id key and then for a list stored under one of several collection
    keys. Any other payload type yields ``None``.

    Args:
        payload: Decoded response of unknown shape.

    Returns:
        The first non-blank id found (whitespace-stripped), else ``None``.
    """
    default_keys = ("defaultId", "default_id", "defaultAgentId", "default_agent_id")
    collection_keys = ("agents", "items", "list", "data")
    entry_keys = ("id", "agentId", "agent_id")

    def _clean(candidate: object) -> str | None:
        # Only non-blank strings count as ids.
        if isinstance(candidate, str):
            return candidate.strip() or None
        return None

    def _first_in(entries: object) -> str | None:
        if not isinstance(entries, list):
            return None
        for entry in entries:
            # A dict entry offers several possible id fields; anything else
            # is only usable when it is itself a non-blank string.
            if isinstance(entry, dict):
                candidates = [entry.get(name) for name in entry_keys]
            else:
                candidates = [entry]
            for candidate in candidates:
                found = _clean(candidate)
                if found is not None:
                    return found
        return None

    if isinstance(payload, list):
        return _first_in(payload)
    if not isinstance(payload, dict):
        return None

    # An explicit default id takes priority over anything in a collection.
    for name in default_keys:
        found = _clean(payload.get(name))
        if found is not None:
            return found

    for name in collection_keys:
        found = _first_in(payload.get(name))
        if found:
            return found
    return None
def _agent_key(agent: Agent) -> str:
session_key = agent.openclaw_session_id or ""
if session_key.startswith("agent:"):
@@ -383,24 +426,18 @@ def _render_agent_files(
async def _gateway_default_agent_id(
    config: GatewayClientConfig,
    *,
    fallback_session_key: str | None = None,
) -> str | None:
    """Resolve the gateway's default agent id, with a session-key fallback.

    Calls ``agents.list`` on the gateway and extracts an id from the response
    via ``_extract_agent_id``. If the call raises ``OpenClawGatewayError`` or
    the response carries no usable id, the id is derived from
    ``fallback_session_key`` instead.

    Args:
        config: Gateway client configuration passed to ``openclaw_call``.
        fallback_session_key: Optional ``agent:<id>[:...]`` session key used
            when the gateway cannot supply an id.

    Returns:
        The resolved agent id, or ``None`` when neither source yields one.
    """
    try:
        payload = await openclaw_call("agents.list", config=config)
    except OpenClawGatewayError:
        # Gateway unavailable: fall back rather than fail outright.
        return _agent_id_from_session_key(fallback_session_key)
    agent_id = _extract_agent_id(payload)
    if agent_id:
        return agent_id
    return _agent_id_from_session_key(fallback_session_key)
async def _patch_gateway_agent_list(
@@ -585,7 +622,10 @@ async def provision_main_agent(
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
agent_id = await _gateway_default_agent_id(client_config)
agent_id = await _gateway_default_agent_id(
client_config,
fallback_session_key=gateway.main_session_key,
)
if not agent_id:
raise OpenClawGatewayError("Unable to resolve gateway main agent id")

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import asyncio
import re
from collections.abc import Awaitable, Callable
from typing import TypeVar
from uuid import UUID, uuid4
from sqlmodel import col, select
@@ -19,12 +22,94 @@ from app.services.agent_provisioning import provision_agent, provision_main_agen
# Matches one whole ``KEY=value`` line: the key is one or more uppercase
# letters, digits, or underscores; the value is everything after the first "=".
# NOTE(review): presumably applied to TOOLS.md lines -- confirm at call sites.
_TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$")
# Generic result type of the awaitable wrapped by ``_with_gateway_retry``.
T = TypeVar("T")
def _slugify(value: str) -> str:
    """Turn ``value`` into a lowercase, hyphen-separated slug.

    Each run of characters outside ``a-z0-9`` (after lowercasing) becomes a
    single hyphen, and leading/trailing hyphens are removed. If nothing is
    left, a random UUID4 hex string is returned instead.
    """
    lowered = value.lower()
    hyphenated = re.sub(r"[^a-z0-9]+", "-", lowered)
    slug = hyphenated.strip("-")
    if slug:
        return slug
    return uuid4().hex
def _is_transient_gateway_error(exc: Exception) -> bool:
    """Decide from the error message whether a gateway failure is retryable.

    Only ``OpenClawGatewayError`` instances with a non-empty message are
    considered. A message containing "unsupported file" is never transient.
    Otherwise the lowercased message is matched against a fixed set of
    substrings.

    Args:
        exc: The exception raised by a gateway operation.

    Returns:
        ``True`` when the message matches a known transient marker.
    """
    if not isinstance(exc, OpenClawGatewayError):
        return False
    text = str(exc).lower()
    if not text or "unsupported file" in text:
        return False
    transient_markers = (
        "received 1012",
        "service restart",
        "http 503",
        "temporar",
        "timeout",
        "timed out",
        "connection closed",
        "connection reset",
    )
    if any(marker in text for marker in transient_markers):
        return True
    # A bare "503" only counts when the message also mentions a websocket.
    return "503" in text and "websocket" in text
async def _with_gateway_retry(
    fn: Callable[[], Awaitable[T]],
    *,
    attempts: int = 3,
    base_delay_s: float = 0.75,
) -> T:
    """Await ``fn()``, retrying transient gateway failures with backoff.

    Args:
        fn: Zero-argument callable returning a fresh awaitable on each call.
        attempts: Maximum number of calls to ``fn``; must be at least 1.
        base_delay_s: Delay before the first retry, doubled for each
            subsequent retry (``base_delay_s * 2**attempt``).

    Returns:
        Whatever ``fn()`` resolves to on the first successful attempt.

    Raises:
        ValueError: If ``attempts`` is less than 1.
        Exception: The exception from ``fn`` is re-raised unchanged when it
            is not transient (per ``_is_transient_gateway_error``) or when
            the final attempt fails.
    """
    # Without this guard the loop would never run and the function would
    # fall through to the "unreachable" assertion without ever calling fn.
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    final_attempt = attempts - 1
    for attempt in range(attempts):
        try:
            return await fn()
        except Exception as exc:
            if attempt >= final_attempt or not _is_transient_gateway_error(exc):
                raise
            await asyncio.sleep(base_delay_s * (2**attempt))
    raise AssertionError("unreachable")
def _agent_id_from_session_key(session_key: str | None) -> str | None:
    """Return the agent id embedded in an ``agent:<id>[:...]`` session key.

    Args:
        session_key: Session key to inspect; ``None`` is treated as empty.

    Returns:
        The whitespace-stripped second ``:``-separated segment, or ``None``
        when the key is empty, does not start with ``agent:``, or that
        segment is blank.
    """
    value = (session_key or "").strip()
    # An empty value fails this check too, so no separate emptiness test.
    if not value.startswith("agent:"):
        return None
    # The prefix check guarantees a ":" exists, so index 1 is always present.
    agent_id = value.split(":", 2)[1].strip()
    return agent_id or None
def _extract_agent_id(payload: object) -> str | None:
    """Find the first usable agent id in a gateway listing payload.

    Lists are scanned element by element. Dicts are checked for a default-id
    key before their nested lists are scanned. Every other type returns
    ``None``.

    Args:
        payload: Decoded response of unknown shape.

    Returns:
        The first non-blank id found (whitespace-stripped), else ``None``.
    """
    default_keys = ("defaultId", "default_id", "defaultAgentId", "default_agent_id")
    collection_keys = ("agents", "items", "list", "data")
    entry_keys = ("id", "agentId", "agent_id")

    def _clean(candidate: object) -> str | None:
        # Only non-blank strings count as ids.
        if isinstance(candidate, str):
            return candidate.strip() or None
        return None

    def _first_in(entries: object) -> str | None:
        if not isinstance(entries, list):
            return None
        for entry in entries:
            # Dict entries are probed field by field; other entries are only
            # usable when they are themselves non-blank strings.
            if isinstance(entry, dict):
                candidates = [entry.get(name) for name in entry_keys]
            else:
                candidates = [entry]
            for candidate in candidates:
                found = _clean(candidate)
                if found is not None:
                    return found
        return None

    if isinstance(payload, list):
        return _first_in(payload)
    if not isinstance(payload, dict):
        return None

    # An explicit default id takes priority over anything in a collection.
    for name in default_keys:
        found = _clean(payload.get(name))
        if found is not None:
            return found

    for name in collection_keys:
        found = _first_in(payload.get(name))
        if found:
            return found
    return None
def _gateway_agent_id(agent: Agent) -> str:
session_key = agent.openclaw_session_id or ""
if session_key.startswith("agent:"):
@@ -94,24 +179,34 @@ async def _get_existing_auth_token(
return token or None
async def _gateway_default_agent_id(
    config: GatewayClientConfig,
    *,
    fallback_session_key: str | None = None,
) -> str | None:
    """Resolve the gateway's default agent id, retrying transient failures.

    Calls ``agents.list`` up to three times. An ``OpenClawGatewayError`` whose
    message looks transient triggers another attempt after an exponential
    delay (0.5s, then 1s). Any other gateway error, or a successful response
    without a usable id, stops the loop. When no id is obtained from the
    gateway, the id is derived from ``fallback_session_key``.

    Args:
        config: Gateway client configuration passed to ``openclaw_call``.
        fallback_session_key: Optional ``agent:<id>[:...]`` session key used
            when the gateway cannot supply an id.

    Returns:
        The resolved agent id, or ``None`` when neither source yields one.
    """
    max_attempts = 3
    retryable_markers = ("503", "temporar", "rejected", "timeout")
    # Gateways may reject WS connects transiently under load (HTTP 503).
    for attempt in range(max_attempts):
        try:
            payload = await openclaw_call("agents.list", config=config)
        except OpenClawGatewayError as exc:
            message = str(exc).lower()
            if not any(marker in message for marker in retryable_markers):
                break
            # No point waiting after the last attempt; go to the fallback.
            if attempt < max_attempts - 1:
                await asyncio.sleep(0.5 * (2**attempt))
            continue
        agent_id = _extract_agent_id(payload)
        if agent_id:
            return agent_id
        # The gateway answered but exposed no id; retrying will not help.
        break
    return _agent_id_from_session_key(fallback_session_key)
async def sync_gateway_templates(
@@ -226,6 +321,7 @@ async def sync_gateway_templates(
)
try:
async def _do_provision() -> None:
await provision_agent(
agent,
board,
@@ -236,6 +332,8 @@ async def sync_gateway_templates(
force_bootstrap=force_bootstrap,
reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision)
result.agents_updated += 1
except Exception as exc: # pragma: no cover - gateway/network dependent
result.agents_skipped += 1
@@ -262,7 +360,10 @@ async def sync_gateway_templates(
)
return result
main_gateway_agent_id = await _gateway_default_agent_id(client_config)
main_gateway_agent_id = await _gateway_default_agent_id(
client_config,
fallback_session_key=gateway.main_session_key,
)
if not main_gateway_agent_id:
result.errors.append(
GatewayTemplatesSyncError(
@@ -277,6 +378,15 @@ async def sync_gateway_templates(
agent_gateway_id=main_gateway_agent_id, config=client_config
)
if not main_token:
if rotate_tokens:
raw_token = generate_agent_token()
main_agent.agent_token_hash = hash_agent_token(raw_token)
main_agent.updated_at = utcnow()
session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
@@ -286,7 +396,28 @@ async def sync_gateway_templates(
)
return result
if main_agent.agent_token_hash and not verify_agent_token(
main_token, main_agent.agent_token_hash
):
if rotate_tokens:
raw_token = generate_agent_token()
main_agent.agent_token_hash = hash_agent_token(raw_token)
main_agent.updated_at = utcnow()
session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).",
)
)
try:
async def _do_provision_main() -> None:
await provision_main_agent(
main_agent,
gateway,
@@ -296,6 +427,8 @@ async def sync_gateway_templates(
force_bootstrap=force_bootstrap,
reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision_main)
result.main_updated = True
except Exception as exc: # pragma: no cover - gateway/network dependent
result.errors.append(

View File

@@ -2,11 +2,16 @@ from __future__ import annotations
import argparse
import asyncio
import sys
from pathlib import Path
from uuid import UUID
from app.db.session import async_session_maker
from app.models.gateways import Gateway
from app.services.template_sync import sync_gateway_templates
# Directory one level above this script's own directory, placed first on
# sys.path before the ``app.*`` imports below (hence their E402 suppressions).
# NOTE(review): presumably lets the script run directly without the backend
# package being installed -- confirm the intended invocation.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(BACKEND_ROOT))
from app.db.session import async_session_maker # noqa: E402
from app.models.gateways import Gateway # noqa: E402
from app.services.template_sync import sync_gateway_templates # noqa: E402
def _parse_args() -> argparse.Namespace: