feat(agent): enhance agent ID extraction and fallback handling in gateway integration
This commit is contained in:
@@ -86,6 +86,49 @@ def _slugify(value: str) -> str:
|
||||
return slug or uuid4().hex
|
||||
|
||||
|
||||
def _agent_id_from_session_key(session_key: str | None) -> str | None:
    """Return the agent id embedded in an ``agent:<id>[:...]`` session key.

    Args:
        session_key: Session key to inspect; ``None`` is treated as empty.

    Returns:
        The whitespace-stripped text between the first and second ``:`` when
        the stripped key starts with ``agent:`` and that segment is non-empty;
        otherwise ``None``.
    """
    value = (session_key or "").strip()
    if not value.startswith("agent:"):
        # Covers the empty/None case too: "" never starts with the prefix.
        return None
    # The prefix guarantees at least two segments, so index 1 always exists
    # and no separate length check is needed.
    agent_id = value.split(":", 2)[1].strip()
    return agent_id or None
|
||||
|
||||
|
||||
def _extract_agent_id(payload: object) -> str | None:
    """Pick an agent id out of a loosely-shaped payload.

    A list payload is scanned directly. A dict payload is checked first for a
    default-id key (``defaultId``, ``default_id``, ``defaultAgentId``,
    ``default_agent_id``) and then for a list stored under ``agents``,
    ``items``, ``list`` or ``data``. Any other payload type yields ``None``.

    Returns:
        The first non-blank string id found, stripped of surrounding
        whitespace, or ``None`` when nothing usable is present.
    """

    def _clean(candidate: object) -> str | None:
        # Only non-blank strings count as ids.
        if not isinstance(candidate, str):
            return None
        return candidate.strip() or None

    def _first_id(entries: object) -> str | None:
        if not isinstance(entries, list):
            return None
        for entry in entries:
            if isinstance(entry, dict):
                candidates = tuple(
                    entry.get(name) for name in ("id", "agentId", "agent_id")
                )
            else:
                # Bare strings are accepted as ids; anything else is ignored
                # by _clean.
                candidates = (entry,)
            for candidate in candidates:
                found = _clean(candidate)
                if found:
                    return found
        return None

    if isinstance(payload, list):
        return _first_id(payload)
    if isinstance(payload, dict):
        for name in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
            found = _clean(payload.get(name))
            if found:
                return found
        for name in ("agents", "items", "list", "data"):
            found = _first_id(payload.get(name))
            if found:
                return found
    return None
|
||||
|
||||
|
||||
def _agent_key(agent: Agent) -> str:
|
||||
session_key = agent.openclaw_session_id or ""
|
||||
if session_key.startswith("agent:"):
|
||||
@@ -383,24 +426,18 @@ def _render_agent_files(
|
||||
|
||||
async def _gateway_default_agent_id(
    config: GatewayClientConfig,
    *,
    fallback_session_key: str | None = None,
) -> str | None:
    """Resolve the gateway's default agent id, with a session-key fallback.

    Calls ``agents.list`` through ``openclaw_call`` and extracts an id from
    the response with ``_extract_agent_id``. If the call raises
    ``OpenClawGatewayError`` or the response yields no id, the id is derived
    from ``fallback_session_key`` instead.

    Args:
        config: Gateway client configuration passed to ``openclaw_call``.
        fallback_session_key: Optional session key handed to
            ``_agent_id_from_session_key`` when the gateway gives no id.

    Returns:
        The resolved agent id, or ``None`` when neither source provides one.
    """
    try:
        payload = await openclaw_call("agents.list", config=config)
    except OpenClawGatewayError:
        # Gateway unavailable: the session key is the only remaining source.
        return _agent_id_from_session_key(fallback_session_key)

    agent_id = _extract_agent_id(payload)
    if agent_id:
        return agent_id
    return _agent_id_from_session_key(fallback_session_key)
|
||||
|
||||
|
||||
async def _patch_gateway_agent_list(
|
||||
@@ -585,7 +622,10 @@ async def provision_main_agent(
|
||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
|
||||
|
||||
agent_id = await _gateway_default_agent_id(client_config)
|
||||
agent_id = await _gateway_default_agent_id(
|
||||
client_config,
|
||||
fallback_session_key=gateway.main_session_key,
|
||||
)
|
||||
if not agent_id:
|
||||
raise OpenClawGatewayError("Unable to resolve gateway main agent id")
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import col, select
|
||||
@@ -19,12 +22,94 @@ from app.services.agent_provisioning import provision_agent, provision_main_agen
|
||||
|
||||
_TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
    """Turn *value* into a lowercase, dash-separated slug.

    Every run of characters outside ``a-z0-9`` (after lowercasing) becomes a
    single ``-``, and leading/trailing dashes are removed. If nothing is left,
    a random ``uuid4`` hex string is returned instead.
    """
    lowered = value.lower()
    dashed = re.sub(r"[^a-z0-9]+", "-", lowered)
    trimmed = dashed.strip("-")
    if trimmed:
        return trimmed
    return uuid4().hex
|
||||
|
||||
|
||||
def _is_transient_gateway_error(exc: Exception) -> bool:
    """Report whether *exc* looks like a temporary gateway failure.

    Only ``OpenClawGatewayError`` instances with a non-empty message are
    considered. The decision is made by case-insensitive substring matching
    on the message:

    * ``unsupported file`` always means "not transient";
    * ``503`` together with ``websocket``, or any one of the markers listed
      below, means "transient".

    Returns:
        ``True`` when the message matches a transient marker, else ``False``.
    """
    text = str(exc).lower() if isinstance(exc, OpenClawGatewayError) else ""
    if not text or "unsupported file" in text:
        return False
    if "503" in text and "websocket" in text:
        return True
    transient_markers = (
        "received 1012",
        "service restart",
        "http 503",
        "temporar",
        "timeout",
        "timed out",
        "connection closed",
        "connection reset",
    )
    return any(marker in text for marker in transient_markers)
|
||||
|
||||
|
||||
async def _with_gateway_retry(
    fn: Callable[[], Awaitable[T]],
    *,
    attempts: int = 3,
    base_delay_s: float = 0.75,
) -> T:
    """Await ``fn()`` and retry it when it fails with a transient gateway error.

    Args:
        fn: Zero-argument callable returning an awaitable; invoked once per
            attempt.
        attempts: Maximum number of invocations.
        base_delay_s: Delay before the first retry; doubled for each
            subsequent retry.

    Returns:
        Whatever the first successful ``fn()`` call returns.

    Raises:
        Exception: The exception from ``fn()`` is re-raised unchanged when it
            is not classified as transient by ``_is_transient_gateway_error``
            or when the last attempt has been used.
        AssertionError: If the loop finishes without returning or raising
            (only possible when ``attempts`` is not positive).
    """
    last_index = attempts - 1
    for attempt in range(attempts):
        try:
            result = await fn()
        except Exception as exc:
            out_of_retries = attempt >= last_index
            if out_of_retries or not _is_transient_gateway_error(exc):
                raise
            delay = base_delay_s * (2**attempt)
        else:
            return result
        # Exponential back-off before the next attempt.
        await asyncio.sleep(delay)
    raise AssertionError("unreachable")
|
||||
|
||||
|
||||
def _agent_id_from_session_key(session_key: str | None) -> str | None:
    """Return the agent id embedded in an ``agent:<id>[:...]`` session key.

    Args:
        session_key: Session key to inspect; ``None`` is treated as empty.

    Returns:
        The whitespace-stripped text between the first and second ``:`` when
        the stripped key starts with ``agent:`` and that segment is non-empty;
        otherwise ``None``.
    """
    value = (session_key or "").strip()
    if not value.startswith("agent:"):
        # Covers the empty/None case too: "" never starts with the prefix.
        return None
    # The prefix guarantees at least two segments, so index 1 always exists
    # and no separate length check is needed.
    agent_id = value.split(":", 2)[1].strip()
    return agent_id or None
|
||||
|
||||
|
||||
def _extract_agent_id(payload: object) -> str | None:
    """Pick an agent id out of a loosely-shaped payload.

    A list payload is scanned directly. A dict payload is checked first for a
    default-id key (``defaultId``, ``default_id``, ``defaultAgentId``,
    ``default_agent_id``) and then for a list stored under ``agents``,
    ``items``, ``list`` or ``data``. Any other payload type yields ``None``.

    Returns:
        The first non-blank string id found, stripped of surrounding
        whitespace, or ``None`` when nothing usable is present.
    """

    def _clean(candidate: object) -> str | None:
        # Only non-blank strings count as ids.
        if not isinstance(candidate, str):
            return None
        return candidate.strip() or None

    def _first_id(entries: object) -> str | None:
        if not isinstance(entries, list):
            return None
        for entry in entries:
            if isinstance(entry, dict):
                candidates = tuple(
                    entry.get(name) for name in ("id", "agentId", "agent_id")
                )
            else:
                # Bare strings are accepted as ids; anything else is ignored
                # by _clean.
                candidates = (entry,)
            for candidate in candidates:
                found = _clean(candidate)
                if found:
                    return found
        return None

    if isinstance(payload, list):
        return _first_id(payload)
    if isinstance(payload, dict):
        for name in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
            found = _clean(payload.get(name))
            if found:
                return found
        for name in ("agents", "items", "list", "data"):
            found = _first_id(payload.get(name))
            if found:
                return found
    return None
|
||||
|
||||
|
||||
def _gateway_agent_id(agent: Agent) -> str:
|
||||
session_key = agent.openclaw_session_id or ""
|
||||
if session_key.startswith("agent:"):
|
||||
@@ -94,24 +179,34 @@ async def _get_existing_auth_token(
|
||||
return token or None
|
||||
|
||||
|
||||
async def _gateway_default_agent_id(
    config: GatewayClientConfig,
    *,
    fallback_session_key: str | None = None,
) -> str | None:
    """Resolve the gateway's default agent id, retrying transient failures.

    Calls ``agents.list`` through ``openclaw_call`` up to three times. A
    successful response is passed to ``_extract_agent_id``; a response with no
    usable id is not retried. An ``OpenClawGatewayError`` is retried only when
    its lowercased message contains ``503``, ``temporar``, ``rejected`` or
    ``timeout``, with an exponential back-off between attempts
    (0.5s, then 1.0s).

    Args:
        config: Gateway client configuration passed to ``openclaw_call``.
        fallback_session_key: Optional session key handed to
            ``_agent_id_from_session_key`` when the gateway gives no id.

    Returns:
        The id extracted from the gateway response, otherwise the id derived
        from ``fallback_session_key``, otherwise ``None``.
    """
    max_attempts = 3
    retryable_markers = ("503", "temporar", "rejected", "timeout")
    # Gateways may reject WS connects transiently under load (HTTP 503).
    for attempt in range(max_attempts):
        try:
            payload = await openclaw_call("agents.list", config=config)
        except OpenClawGatewayError as exc:
            message = str(exc).lower()
            if not any(marker in message for marker in retryable_markers):
                # Non-transient failure: retrying will not help.
                break
            if attempt < max_attempts - 1:
                # Back off only when another attempt follows; sleeping after
                # the final attempt would just delay the fallback.
                await asyncio.sleep(0.5 * (2**attempt))
            continue
        agent_id = _extract_agent_id(payload)
        if agent_id:
            return agent_id
        # The gateway answered but exposed no id; asking again is pointless.
        break

    return _agent_id_from_session_key(fallback_session_key)
|
||||
|
||||
|
||||
async def sync_gateway_templates(
|
||||
@@ -226,6 +321,7 @@ async def sync_gateway_templates(
|
||||
)
|
||||
|
||||
try:
|
||||
async def _do_provision() -> None:
|
||||
await provision_agent(
|
||||
agent,
|
||||
board,
|
||||
@@ -236,6 +332,8 @@ async def sync_gateway_templates(
|
||||
force_bootstrap=force_bootstrap,
|
||||
reset_session=reset_sessions,
|
||||
)
|
||||
|
||||
await _with_gateway_retry(_do_provision)
|
||||
result.agents_updated += 1
|
||||
except Exception as exc: # pragma: no cover - gateway/network dependent
|
||||
result.agents_skipped += 1
|
||||
@@ -262,7 +360,10 @@ async def sync_gateway_templates(
|
||||
)
|
||||
return result
|
||||
|
||||
main_gateway_agent_id = await _gateway_default_agent_id(client_config)
|
||||
main_gateway_agent_id = await _gateway_default_agent_id(
|
||||
client_config,
|
||||
fallback_session_key=gateway.main_session_key,
|
||||
)
|
||||
if not main_gateway_agent_id:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
@@ -277,6 +378,15 @@ async def sync_gateway_templates(
|
||||
agent_gateway_id=main_gateway_agent_id, config=client_config
|
||||
)
|
||||
if not main_token:
|
||||
if rotate_tokens:
|
||||
raw_token = generate_agent_token()
|
||||
main_agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
main_agent.updated_at = utcnow()
|
||||
session.add(main_agent)
|
||||
await session.commit()
|
||||
await session.refresh(main_agent)
|
||||
main_token = raw_token
|
||||
else:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
@@ -286,7 +396,28 @@ async def sync_gateway_templates(
|
||||
)
|
||||
return result
|
||||
|
||||
if main_agent.agent_token_hash and not verify_agent_token(
|
||||
main_token, main_agent.agent_token_hash
|
||||
):
|
||||
if rotate_tokens:
|
||||
raw_token = generate_agent_token()
|
||||
main_agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
main_agent.updated_at = utcnow()
|
||||
session.add(main_agent)
|
||||
await session.commit()
|
||||
await session.refresh(main_agent)
|
||||
main_token = raw_token
|
||||
else:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async def _do_provision_main() -> None:
|
||||
await provision_main_agent(
|
||||
main_agent,
|
||||
gateway,
|
||||
@@ -296,6 +427,8 @@ async def sync_gateway_templates(
|
||||
force_bootstrap=force_bootstrap,
|
||||
reset_session=reset_sessions,
|
||||
)
|
||||
|
||||
await _with_gateway_retry(_do_provision_main)
|
||||
result.main_updated = True
|
||||
except Exception as exc: # pragma: no cover - gateway/network dependent
|
||||
result.errors.append(
|
||||
|
||||
@@ -2,11 +2,16 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from app.db.session import async_session_maker
|
||||
from app.models.gateways import Gateway
|
||||
from app.services.template_sync import sync_gateway_templates
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
|
||||
from app.db.session import async_session_maker # noqa: E402
|
||||
from app.models.gateways import Gateway # noqa: E402
|
||||
from app.services.template_sync import sync_gateway_templates # noqa: E402
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
|
||||
Reference in New Issue
Block a user