feat(agent): enhance agent ID extraction and fallback handling in gateway integration
This commit is contained in:
@@ -86,6 +86,49 @@ def _slugify(value: str) -> str:
|
|||||||
return slug or uuid4().hex
|
return slug or uuid4().hex
|
||||||
|
|
||||||
|
|
||||||
|
def _agent_id_from_session_key(session_key: str | None) -> str | None:
|
||||||
|
value = (session_key or "").strip()
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
if not value.startswith("agent:"):
|
||||||
|
return None
|
||||||
|
parts = value.split(":")
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None
|
||||||
|
agent_id = parts[1].strip()
|
||||||
|
return agent_id or None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_agent_id(payload: object) -> str | None:
|
||||||
|
def _from_list(items: object) -> str | None:
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return None
|
||||||
|
for item in items:
|
||||||
|
if isinstance(item, str) and item.strip():
|
||||||
|
return item.strip()
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
for key in ("id", "agentId", "agent_id"):
|
||||||
|
raw = item.get(key)
|
||||||
|
if isinstance(raw, str) and raw.strip():
|
||||||
|
return raw.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(payload, list):
|
||||||
|
return _from_list(payload)
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return None
|
||||||
|
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
|
||||||
|
raw = payload.get(key)
|
||||||
|
if isinstance(raw, str) and raw.strip():
|
||||||
|
return raw.strip()
|
||||||
|
for key in ("agents", "items", "list", "data"):
|
||||||
|
agent_id = _from_list(payload.get(key))
|
||||||
|
if agent_id:
|
||||||
|
return agent_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _agent_key(agent: Agent) -> str:
|
def _agent_key(agent: Agent) -> str:
|
||||||
session_key = agent.openclaw_session_id or ""
|
session_key = agent.openclaw_session_id or ""
|
||||||
if session_key.startswith("agent:"):
|
if session_key.startswith("agent:"):
|
||||||
@@ -383,24 +426,18 @@ def _render_agent_files(
|
|||||||
|
|
||||||
async def _gateway_default_agent_id(
|
async def _gateway_default_agent_id(
|
||||||
config: GatewayClientConfig,
|
config: GatewayClientConfig,
|
||||||
|
*,
|
||||||
|
fallback_session_key: str | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
try:
|
try:
|
||||||
payload = await openclaw_call("agents.list", config=config)
|
payload = await openclaw_call("agents.list", config=config)
|
||||||
except OpenClawGatewayError:
|
except OpenClawGatewayError:
|
||||||
return None
|
return _agent_id_from_session_key(fallback_session_key)
|
||||||
if not isinstance(payload, dict):
|
|
||||||
return None
|
agent_id = _extract_agent_id(payload)
|
||||||
default_id = payload.get("defaultId") or payload.get("default_id")
|
if agent_id:
|
||||||
if isinstance(default_id, str) and default_id:
|
return agent_id
|
||||||
return default_id
|
return _agent_id_from_session_key(fallback_session_key)
|
||||||
agents = payload.get("agents") or []
|
|
||||||
if isinstance(agents, list) and agents:
|
|
||||||
first = agents[0]
|
|
||||||
if isinstance(first, dict):
|
|
||||||
agent_id = first.get("id")
|
|
||||||
if isinstance(agent_id, str) and agent_id:
|
|
||||||
return agent_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def _patch_gateway_agent_list(
|
async def _patch_gateway_agent_list(
|
||||||
@@ -585,7 +622,10 @@ async def provision_main_agent(
|
|||||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||||
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
|
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
|
||||||
|
|
||||||
agent_id = await _gateway_default_agent_id(client_config)
|
agent_id = await _gateway_default_agent_id(
|
||||||
|
client_config,
|
||||||
|
fallback_session_key=gateway.main_session_key,
|
||||||
|
)
|
||||||
if not agent_id:
|
if not agent_id:
|
||||||
raise OpenClawGatewayError("Unable to resolve gateway main agent id")
|
raise OpenClawGatewayError("Unable to resolve gateway main agent id")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TypeVar
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
@@ -19,12 +22,94 @@ from app.services.agent_provisioning import provision_agent, provision_main_agen
|
|||||||
|
|
||||||
_TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$")
|
_TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$")
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def _slugify(value: str) -> str:
|
def _slugify(value: str) -> str:
|
||||||
slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
|
slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
|
||||||
return slug or uuid4().hex
|
return slug or uuid4().hex
|
||||||
|
|
||||||
|
|
||||||
|
def _is_transient_gateway_error(exc: Exception) -> bool:
    """Tell whether ``exc`` looks like a temporary gateway failure.

    Only ``OpenClawGatewayError`` instances are considered; the decision is
    made purely from the lower-cased exception message.

    Returns:
        ``True`` when the message contains one of the known transient
        markers, ``False`` for other exception types, empty messages,
        "unsupported file" errors, and anything unrecognised.
    """
    if not isinstance(exc, OpenClawGatewayError):
        return False

    text = str(exc).lower()
    # "unsupported file" is never treated as transient, even if the message
    # also happens to contain one of the markers below.
    if not text or "unsupported file" in text:
        return False

    # A bare "503" only counts when the message also mentions a websocket.
    if "503" in text and "websocket" in text:
        return True

    transient_markers = (
        "received 1012",  # presumably WebSocket close code 1012 - TODO confirm
        "service restart",
        "http 503",
        "temporar",  # matches both "temporary" and "temporarily"
        "timeout",
        "timed out",
        "connection closed",
        "connection reset",
    )
    return any(marker in text for marker in transient_markers)
|
||||||
|
|
||||||
|
|
||||||
|
async def _with_gateway_retry(
|
||||||
|
fn: Callable[[], Awaitable[T]],
|
||||||
|
*,
|
||||||
|
attempts: int = 3,
|
||||||
|
base_delay_s: float = 0.75,
|
||||||
|
) -> T:
|
||||||
|
for attempt in range(attempts):
|
||||||
|
try:
|
||||||
|
return await fn()
|
||||||
|
except Exception as exc:
|
||||||
|
if attempt >= attempts - 1 or not _is_transient_gateway_error(exc):
|
||||||
|
raise
|
||||||
|
await asyncio.sleep(base_delay_s * (2**attempt))
|
||||||
|
raise AssertionError("unreachable")
|
||||||
|
|
||||||
|
|
||||||
|
def _agent_id_from_session_key(session_key: str | None) -> str | None:
|
||||||
|
value = (session_key or "").strip()
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
if not value.startswith("agent:"):
|
||||||
|
return None
|
||||||
|
parts = value.split(":")
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None
|
||||||
|
agent_id = parts[1].strip()
|
||||||
|
return agent_id or None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_agent_id(payload: object) -> str | None:
|
||||||
|
def _from_list(items: object) -> str | None:
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return None
|
||||||
|
for item in items:
|
||||||
|
if isinstance(item, str) and item.strip():
|
||||||
|
return item.strip()
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
for key in ("id", "agentId", "agent_id"):
|
||||||
|
raw = item.get(key)
|
||||||
|
if isinstance(raw, str) and raw.strip():
|
||||||
|
return raw.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(payload, list):
|
||||||
|
return _from_list(payload)
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return None
|
||||||
|
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
|
||||||
|
raw = payload.get(key)
|
||||||
|
if isinstance(raw, str) and raw.strip():
|
||||||
|
return raw.strip()
|
||||||
|
for key in ("agents", "items", "list", "data"):
|
||||||
|
agent_id = _from_list(payload.get(key))
|
||||||
|
if agent_id:
|
||||||
|
return agent_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _gateway_agent_id(agent: Agent) -> str:
|
def _gateway_agent_id(agent: Agent) -> str:
|
||||||
session_key = agent.openclaw_session_id or ""
|
session_key = agent.openclaw_session_id or ""
|
||||||
if session_key.startswith("agent:"):
|
if session_key.startswith("agent:"):
|
||||||
@@ -94,24 +179,34 @@ async def _get_existing_auth_token(
|
|||||||
return token or None
|
return token or None
|
||||||
|
|
||||||
|
|
||||||
async def _gateway_default_agent_id(config: GatewayClientConfig) -> str | None:
|
async def _gateway_default_agent_id(
|
||||||
try:
|
config: GatewayClientConfig,
|
||||||
payload = await openclaw_call("agents.list", config=config)
|
*,
|
||||||
except OpenClawGatewayError:
|
fallback_session_key: str | None = None,
|
||||||
return None
|
) -> str | None:
|
||||||
if not isinstance(payload, dict):
|
last_error: OpenClawGatewayError | None = None
|
||||||
return None
|
# Gateways may reject WS connects transiently under load (HTTP 503).
|
||||||
default_id = payload.get("defaultId") or payload.get("default_id")
|
for attempt in range(3):
|
||||||
if isinstance(default_id, str) and default_id:
|
try:
|
||||||
return default_id
|
payload = await openclaw_call("agents.list", config=config)
|
||||||
agents = payload.get("agents") or []
|
agent_id = _extract_agent_id(payload)
|
||||||
if isinstance(agents, list) and agents:
|
if agent_id:
|
||||||
first = agents[0]
|
|
||||||
if isinstance(first, dict):
|
|
||||||
agent_id = first.get("id")
|
|
||||||
if isinstance(agent_id, str) and agent_id:
|
|
||||||
return agent_id
|
return agent_id
|
||||||
return None
|
break
|
||||||
|
except OpenClawGatewayError as exc:
|
||||||
|
last_error = exc
|
||||||
|
message = str(exc).lower()
|
||||||
|
if (
|
||||||
|
"503" not in message
|
||||||
|
and "temporar" not in message
|
||||||
|
and "rejected" not in message
|
||||||
|
and "timeout" not in message
|
||||||
|
):
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.5 * (2**attempt))
|
||||||
|
|
||||||
|
_ = last_error
|
||||||
|
return _agent_id_from_session_key(fallback_session_key)
|
||||||
|
|
||||||
|
|
||||||
async def sync_gateway_templates(
|
async def sync_gateway_templates(
|
||||||
@@ -226,16 +321,19 @@ async def sync_gateway_templates(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await provision_agent(
|
async def _do_provision() -> None:
|
||||||
agent,
|
await provision_agent(
|
||||||
board,
|
agent,
|
||||||
gateway,
|
board,
|
||||||
auth_token,
|
gateway,
|
||||||
user,
|
auth_token,
|
||||||
action="update",
|
user,
|
||||||
force_bootstrap=force_bootstrap,
|
action="update",
|
||||||
reset_session=reset_sessions,
|
force_bootstrap=force_bootstrap,
|
||||||
)
|
reset_session=reset_sessions,
|
||||||
|
)
|
||||||
|
|
||||||
|
await _with_gateway_retry(_do_provision)
|
||||||
result.agents_updated += 1
|
result.agents_updated += 1
|
||||||
except Exception as exc: # pragma: no cover - gateway/network dependent
|
except Exception as exc: # pragma: no cover - gateway/network dependent
|
||||||
result.agents_skipped += 1
|
result.agents_skipped += 1
|
||||||
@@ -262,7 +360,10 @@ async def sync_gateway_templates(
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
main_gateway_agent_id = await _gateway_default_agent_id(client_config)
|
main_gateway_agent_id = await _gateway_default_agent_id(
|
||||||
|
client_config,
|
||||||
|
fallback_session_key=gateway.main_session_key,
|
||||||
|
)
|
||||||
if not main_gateway_agent_id:
|
if not main_gateway_agent_id:
|
||||||
result.errors.append(
|
result.errors.append(
|
||||||
GatewayTemplatesSyncError(
|
GatewayTemplatesSyncError(
|
||||||
@@ -277,25 +378,57 @@ async def sync_gateway_templates(
|
|||||||
agent_gateway_id=main_gateway_agent_id, config=client_config
|
agent_gateway_id=main_gateway_agent_id, config=client_config
|
||||||
)
|
)
|
||||||
if not main_token:
|
if not main_token:
|
||||||
result.errors.append(
|
if rotate_tokens:
|
||||||
GatewayTemplatesSyncError(
|
raw_token = generate_agent_token()
|
||||||
agent_id=main_agent.id,
|
main_agent.agent_token_hash = hash_agent_token(raw_token)
|
||||||
agent_name=main_agent.name,
|
main_agent.updated_at = utcnow()
|
||||||
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
|
session.add(main_agent)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(main_agent)
|
||||||
|
main_token = raw_token
|
||||||
|
else:
|
||||||
|
result.errors.append(
|
||||||
|
GatewayTemplatesSyncError(
|
||||||
|
agent_id=main_agent.id,
|
||||||
|
agent_name=main_agent.name,
|
||||||
|
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if main_agent.agent_token_hash and not verify_agent_token(
|
||||||
|
main_token, main_agent.agent_token_hash
|
||||||
|
):
|
||||||
|
if rotate_tokens:
|
||||||
|
raw_token = generate_agent_token()
|
||||||
|
main_agent.agent_token_hash = hash_agent_token(raw_token)
|
||||||
|
main_agent.updated_at = utcnow()
|
||||||
|
session.add(main_agent)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(main_agent)
|
||||||
|
main_token = raw_token
|
||||||
|
else:
|
||||||
|
result.errors.append(
|
||||||
|
GatewayTemplatesSyncError(
|
||||||
|
agent_id=main_agent.id,
|
||||||
|
agent_name=main_agent.name,
|
||||||
|
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await provision_main_agent(
|
async def _do_provision_main() -> None:
|
||||||
main_agent,
|
await provision_main_agent(
|
||||||
gateway,
|
main_agent,
|
||||||
main_token,
|
gateway,
|
||||||
user,
|
main_token,
|
||||||
action="update",
|
user,
|
||||||
force_bootstrap=force_bootstrap,
|
action="update",
|
||||||
reset_session=reset_sessions,
|
force_bootstrap=force_bootstrap,
|
||||||
)
|
reset_session=reset_sessions,
|
||||||
|
)
|
||||||
|
|
||||||
|
await _with_gateway_retry(_do_provision_main)
|
||||||
result.main_updated = True
|
result.main_updated = True
|
||||||
except Exception as exc: # pragma: no cover - gateway/network dependent
|
except Exception as exc: # pragma: no cover - gateway/network dependent
|
||||||
result.errors.append(
|
result.errors.append(
|
||||||
|
|||||||
@@ -2,11 +2,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from app.db.session import async_session_maker
|
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||||
from app.models.gateways import Gateway
|
sys.path.insert(0, str(BACKEND_ROOT))
|
||||||
from app.services.template_sync import sync_gateway_templates
|
|
||||||
|
from app.db.session import async_session_maker # noqa: E402
|
||||||
|
from app.models.gateways import Gateway # noqa: E402
|
||||||
|
from app.services.template_sync import sync_gateway_templates # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def _parse_args() -> argparse.Namespace:
|
def _parse_args() -> argparse.Namespace:
|
||||||
|
|||||||
Reference in New Issue
Block a user