refactor: standardize runtime annotation types across multiple files

This commit is contained in:
Abhimanyu Saharan
2026-02-09 17:24:21 +05:30
parent 7706943209
commit f5d592f61a
47 changed files with 2203 additions and 1413 deletions

View File

@@ -6,6 +6,7 @@ import hashlib
import json
import re
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
@@ -88,6 +89,36 @@ MAIN_TEMPLATE_MAP = {
}
@dataclass(frozen=True, slots=True)
class ProvisionOptions:
"""Toggles controlling provisioning write/reset behavior."""
action: str = "provision"
force_bootstrap: bool = False
reset_session: bool = False
@dataclass(frozen=True, slots=True)
class AgentProvisionRequest:
"""Inputs required to provision a board-scoped agent."""
board: Board
gateway: Gateway
auth_token: str
user: User | None
options: ProvisionOptions = field(default_factory=ProvisionOptions)
@dataclass(frozen=True, slots=True)
class MainAgentProvisionRequest:
"""Inputs required to provision a gateway main agent."""
gateway: Gateway
auth_token: str
user: User | None
options: ProvisionOptions = field(default_factory=ProvisionOptions)
def _repo_root() -> Path:
return Path(__file__).resolve().parents[3]
@@ -114,31 +145,48 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
return agent_id or None
def _extract_agent_id(payload: object) -> str | None: # noqa: C901
def _from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
for item in items:
if isinstance(item, str) and item.strip():
return item.strip()
if not isinstance(item, dict):
continue
for key in ("id", "agentId", "agent_id"):
raw = item.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
def _clean_str(value: object) -> str | None:
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _extract_agent_id_from_item(item: object) -> str | None:
if isinstance(item, str):
return _clean_str(item)
if not isinstance(item, dict):
return None
for key in ("id", "agentId", "agent_id"):
agent_id = _clean_str(item.get(key))
if agent_id:
return agent_id
return None
def _extract_agent_id_from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
for item in items:
agent_id = _extract_agent_id_from_item(item)
if agent_id:
return agent_id
return None
def _extract_agent_id(payload: object) -> str | None:
default_keys = ("defaultId", "default_id", "defaultAgentId", "default_agent_id")
collection_keys = ("agents", "items", "list", "data")
if isinstance(payload, list):
return _from_list(payload)
return _extract_agent_id_from_list(payload)
if not isinstance(payload, dict):
return None
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
raw = payload.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
for key in ("agents", "items", "list", "data"):
agent_id = _from_list(payload.get(key))
for key in default_keys:
agent_id = _clean_str(payload.get(key))
if agent_id:
return agent_id
for key in collection_keys:
agent_id = _extract_agent_id_from_list(payload.get(key))
if agent_id:
return agent_id
return None
@@ -523,42 +571,44 @@ async def _patch_gateway_agent_list(
await openclaw_call("config.patch", params, config=config)
async def patch_gateway_agent_heartbeats( # noqa: C901
gateway: Gateway,
*,
entries: list[tuple[str, str, dict[str, Any]]],
) -> None:
"""Patch multiple agent heartbeat configs in a single gateway config.patch call.
Each entry is (agent_id, workspace_path, heartbeat_dict).
"""
if not gateway.url:
msg = "Gateway url is required"
raise OpenClawGatewayError(msg)
config = GatewayClientConfig(url=gateway.url, token=gateway.token)
async def _gateway_config_agent_list(
config: GatewayClientConfig,
) -> tuple[str | None, list[object]]:
cfg = await openclaw_call("config.get", config=config)
if not isinstance(cfg, dict):
msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
base_hash = cfg.get("hash")
data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict):
msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
agents_section = data.get("agents") or {}
lst = agents_section.get("list") or []
if not isinstance(lst, list):
agents_list = agents_section.get("list") or []
if not isinstance(agents_list, list):
msg = "config agents.list is not a list"
raise OpenClawGatewayError(msg)
return cfg.get("hash"), agents_list
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
def _heartbeat_entry_map(
entries: list[tuple[str, str, dict[str, Any]]],
) -> dict[str, tuple[str, dict[str, Any]]]:
return {
agent_id: (workspace_path, heartbeat)
for agent_id, workspace_path, heartbeat in entries
}
def _updated_agent_list(
raw_list: list[object],
entry_by_id: dict[str, tuple[str, dict[str, Any]]],
) -> list[object]:
updated_ids: set[str] = set()
new_list: list[dict[str, Any]] = []
for raw_entry in lst:
new_list: list[object] = []
for raw_entry in raw_list:
if not isinstance(raw_entry, dict):
new_list.append(raw_entry)
continue
@@ -566,6 +616,7 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
if not isinstance(agent_id, str) or agent_id not in entry_by_id:
new_list.append(raw_entry)
continue
workspace_path, heartbeat = entry_by_id[agent_id]
new_entry = dict(raw_entry)
new_entry["workspace"] = workspace_path
@@ -580,6 +631,26 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
)
return new_list
async def patch_gateway_agent_heartbeats(
gateway: Gateway,
*,
entries: list[tuple[str, str, dict[str, Any]]],
) -> None:
"""Patch multiple agent heartbeat configs in a single gateway config.patch call.
Each entry is (agent_id, workspace_path, heartbeat_dict).
"""
if not gateway.url:
msg = "Gateway url is required"
raise OpenClawGatewayError(msg)
config = GatewayClientConfig(url=gateway.url, token=gateway.token)
base_hash, raw_list = await _gateway_config_agent_list(config)
entry_by_id = _heartbeat_entry_map(entries)
new_list = _updated_agent_list(raw_list, entry_by_id)
patch = {"agents": {"list": new_list}}
params = {"raw": json.dumps(patch)}
if base_hash:
@@ -656,18 +727,52 @@ async def _get_gateway_agent_entry(
return None
async def provision_agent( # noqa: C901, PLR0912, PLR0913
agent: Agent,
board: Board,
gateway: Gateway,
auth_token: str,
user: User | None,
def _should_include_bootstrap(
*,
action: str = "provision",
force_bootstrap: bool = False,
reset_session: bool = False,
action: str,
force_bootstrap: bool,
existing_files: dict[str, dict[str, Any]],
) -> bool:
if action != "update" or force_bootstrap:
return True
if not existing_files:
return False
entry = existing_files.get("BOOTSTRAP.md")
return not (entry and entry.get("missing") is True)
async def _set_agent_files(
*,
agent_id: str,
rendered: dict[str, str],
existing_files: dict[str, dict[str, Any]],
client_config: GatewayClientConfig,
) -> None:
for name, content in rendered.items():
if content == "":
continue
if name in PRESERVE_AGENT_EDITABLE_FILES:
entry = existing_files.get(name)
if entry and entry.get("missing") is not True:
continue
try:
await openclaw_call(
"agents.files.set",
{"agentId": agent_id, "name": name, "content": content},
config=client_config,
)
except OpenClawGatewayError as exc:
if "unsupported file" in str(exc).lower():
continue
raise
async def provision_agent(
agent: Agent,
request: AgentProvisionRequest,
) -> None:
"""Provision or update a regular board agent workspace."""
gateway = request.gateway
if not gateway.url:
return
if not gateway.workspace_root:
@@ -682,18 +787,21 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913
heartbeat = _heartbeat_config(agent)
await _patch_gateway_agent_list(agent_id, workspace_path, heartbeat, client_config)
context = _build_context(agent, board, gateway, auth_token, user)
context = _build_context(
agent,
request.board,
gateway,
request.auth_token,
request.user,
)
supported = set(await _supported_gateway_files(client_config))
supported.update({"USER.md", "SELF.md", "AUTONOMY.md"})
existing_files = await _gateway_agent_files_index(agent_id, client_config)
include_bootstrap = True
if action == "update" and not force_bootstrap:
if not existing_files:
include_bootstrap = False
else:
entry = existing_files.get("BOOTSTRAP.md")
if entry and entry.get("missing") is True:
include_bootstrap = False
include_bootstrap = _should_include_bootstrap(
action=request.options.action,
force_bootstrap=request.options.force_bootstrap,
existing_files=existing_files,
)
rendered = _render_agent_files(
context,
@@ -710,41 +818,22 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913
with suppress(OSError):
# Local workspace may not be writable/available; fall back to gateway API.
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
for name, content in rendered.items():
if content == "":
continue
if name in PRESERVE_AGENT_EDITABLE_FILES:
# Never overwrite; only provision if missing.
entry = existing_files.get(name)
if entry and entry.get("missing") is not True:
continue
try:
await openclaw_call(
"agents.files.set",
{"agentId": agent_id, "name": name, "content": content},
config=client_config,
)
except OpenClawGatewayError as exc:
# Gateways may restrict file names. Skip unsupported files rather than
# failing provisioning for the entire agent.
if "unsupported file" in str(exc).lower():
continue
raise
if reset_session:
await _set_agent_files(
agent_id=agent_id,
rendered=rendered,
existing_files=existing_files,
client_config=client_config,
)
if request.options.reset_session:
await _reset_session(session_key, client_config)
async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
async def provision_main_agent(
agent: Agent,
gateway: Gateway,
auth_token: str,
user: User | None,
*,
action: str = "provision",
force_bootstrap: bool = False,
reset_session: bool = False,
request: MainAgentProvisionRequest,
) -> None:
"""Provision or update the gateway main agent workspace."""
gateway = request.gateway
if not gateway.url:
return
if not gateway.main_session_key:
@@ -763,18 +852,15 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
msg = "Unable to resolve gateway main agent id"
raise OpenClawGatewayError(msg)
context = _build_main_context(agent, gateway, auth_token, user)
context = _build_main_context(agent, gateway, request.auth_token, request.user)
supported = set(await _supported_gateway_files(client_config))
supported.update({"USER.md", "SELF.md", "AUTONOMY.md"})
existing_files = await _gateway_agent_files_index(agent_id, client_config)
include_bootstrap = action != "update" or force_bootstrap
if action == "update" and not force_bootstrap:
if not existing_files:
include_bootstrap = False
else:
entry = existing_files.get("BOOTSTRAP.md")
if entry and entry.get("missing") is True:
include_bootstrap = False
include_bootstrap = _should_include_bootstrap(
action=request.options.action,
force_bootstrap=request.options.force_bootstrap,
existing_files=existing_files,
)
rendered = _render_agent_files(
context,
@@ -783,24 +869,13 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
include_bootstrap=include_bootstrap,
template_overrides=MAIN_TEMPLATE_MAP,
)
for name, content in rendered.items():
if content == "":
continue
if name in PRESERVE_AGENT_EDITABLE_FILES:
entry = existing_files.get(name)
if entry and entry.get("missing") is not True:
continue
try:
await openclaw_call(
"agents.files.set",
{"agentId": agent_id, "name": name, "content": content},
config=client_config,
)
except OpenClawGatewayError as exc:
if "unsupported file" in str(exc).lower():
continue
raise
if reset_session:
await _set_agent_files(
agent_id=agent_id,
rendered=rendered,
existing_files=existing_files,
client_config=client_config,
)
if request.options.reset_session:
await _reset_session(gateway.main_session_key, client_config)