refactor: enhance docstrings for clarity and consistency across multiple files
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
"""Gateway-facing agent provisioning and cleanup helpers."""
|
||||
# ruff: noqa: EM101, TRY003
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -176,7 +175,8 @@ def _heartbeat_template_name(agent: Agent) -> str:
|
||||
|
||||
def _workspace_path(agent: Agent, workspace_root: str) -> str:
|
||||
if not workspace_root:
|
||||
raise ValueError("gateway_workspace_root is required")
|
||||
msg = "gateway_workspace_root is required"
|
||||
raise ValueError(msg)
|
||||
root = workspace_root.rstrip("/")
|
||||
# Use agent key derived from session key when possible. This prevents collisions for
|
||||
# lead agents (session key includes board id) even if multiple boards share the same
|
||||
@@ -227,9 +227,11 @@ def _build_context(
|
||||
user: User | None,
|
||||
) -> dict[str, str]:
|
||||
if not gateway.workspace_root:
|
||||
raise ValueError("gateway_workspace_root is required")
|
||||
msg = "gateway_workspace_root is required"
|
||||
raise ValueError(msg)
|
||||
if not gateway.main_session_key:
|
||||
raise ValueError("gateway_main_session_key is required")
|
||||
msg = "gateway_main_session_key is required"
|
||||
raise ValueError(msg)
|
||||
agent_id = str(agent.id)
|
||||
workspace_root = gateway.workspace_root
|
||||
workspace_path = _workspace_path(agent, workspace_root)
|
||||
@@ -485,15 +487,18 @@ async def _patch_gateway_agent_list(
|
||||
) -> None:
|
||||
cfg = await openclaw_call("config.get", config=config)
|
||||
if not isinstance(cfg, dict):
|
||||
raise OpenClawGatewayError("config.get returned invalid payload")
|
||||
msg = "config.get returned invalid payload"
|
||||
raise OpenClawGatewayError(msg)
|
||||
base_hash = cfg.get("hash")
|
||||
data = cfg.get("config") or cfg.get("parsed") or {}
|
||||
if not isinstance(data, dict):
|
||||
raise OpenClawGatewayError("config.get returned invalid config")
|
||||
msg = "config.get returned invalid config"
|
||||
raise OpenClawGatewayError(msg)
|
||||
agents = data.get("agents") or {}
|
||||
lst = agents.get("list") or []
|
||||
if not isinstance(lst, list):
|
||||
raise OpenClawGatewayError("config agents.list is not a list")
|
||||
msg = "config agents.list is not a list"
|
||||
raise OpenClawGatewayError(msg)
|
||||
|
||||
updated = False
|
||||
new_list: list[dict[str, Any]] = []
|
||||
@@ -528,19 +533,23 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
|
||||
Each entry is (agent_id, workspace_path, heartbeat_dict).
|
||||
"""
|
||||
if not gateway.url:
|
||||
raise OpenClawGatewayError("Gateway url is required")
|
||||
msg = "Gateway url is required"
|
||||
raise OpenClawGatewayError(msg)
|
||||
config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
cfg = await openclaw_call("config.get", config=config)
|
||||
if not isinstance(cfg, dict):
|
||||
raise OpenClawGatewayError("config.get returned invalid payload")
|
||||
msg = "config.get returned invalid payload"
|
||||
raise OpenClawGatewayError(msg)
|
||||
base_hash = cfg.get("hash")
|
||||
data = cfg.get("config") or cfg.get("parsed") or {}
|
||||
if not isinstance(data, dict):
|
||||
raise OpenClawGatewayError("config.get returned invalid config")
|
||||
msg = "config.get returned invalid config"
|
||||
raise OpenClawGatewayError(msg)
|
||||
agents_section = data.get("agents") or {}
|
||||
lst = agents_section.get("list") or []
|
||||
if not isinstance(lst, list):
|
||||
raise OpenClawGatewayError("config agents.list is not a list")
|
||||
msg = "config agents.list is not a list"
|
||||
raise OpenClawGatewayError(msg)
|
||||
|
||||
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
|
||||
agent_id: (workspace_path, heartbeat)
|
||||
@@ -581,7 +590,8 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
|
||||
async def sync_gateway_agent_heartbeats(gateway: Gateway, agents: list[Agent]) -> None:
|
||||
"""Sync current Agent.heartbeat_config values to the gateway config."""
|
||||
if not gateway.workspace_root:
|
||||
raise OpenClawGatewayError("gateway workspace_root is required")
|
||||
msg = "gateway workspace_root is required"
|
||||
raise OpenClawGatewayError(msg)
|
||||
entries: list[tuple[str, str, dict[str, Any]]] = []
|
||||
for agent in agents:
|
||||
agent_id = _agent_key(agent)
|
||||
@@ -599,15 +609,18 @@ async def _remove_gateway_agent_list(
|
||||
) -> None:
|
||||
cfg = await openclaw_call("config.get", config=config)
|
||||
if not isinstance(cfg, dict):
|
||||
raise OpenClawGatewayError("config.get returned invalid payload")
|
||||
msg = "config.get returned invalid payload"
|
||||
raise OpenClawGatewayError(msg)
|
||||
base_hash = cfg.get("hash")
|
||||
data = cfg.get("config") or cfg.get("parsed") or {}
|
||||
if not isinstance(data, dict):
|
||||
raise OpenClawGatewayError("config.get returned invalid config")
|
||||
msg = "config.get returned invalid config"
|
||||
raise OpenClawGatewayError(msg)
|
||||
agents = data.get("agents") or {}
|
||||
lst = agents.get("list") or []
|
||||
if not isinstance(lst, list):
|
||||
raise OpenClawGatewayError("config agents.list is not a list")
|
||||
msg = "config agents.list is not a list"
|
||||
raise OpenClawGatewayError(msg)
|
||||
|
||||
new_list = [
|
||||
entry
|
||||
@@ -658,7 +671,8 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913
|
||||
if not gateway.url:
|
||||
return
|
||||
if not gateway.workspace_root:
|
||||
raise ValueError("gateway_workspace_root is required")
|
||||
msg = "gateway_workspace_root is required"
|
||||
raise ValueError(msg)
|
||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
session_key = _session_key(agent)
|
||||
await ensure_session(session_key, config=client_config, label=agent.name)
|
||||
@@ -734,7 +748,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
|
||||
if not gateway.url:
|
||||
return
|
||||
if not gateway.main_session_key:
|
||||
raise ValueError("gateway main_session_key is required")
|
||||
msg = "gateway main_session_key is required"
|
||||
raise ValueError(msg)
|
||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
await ensure_session(
|
||||
gateway.main_session_key, config=client_config, label="Main Agent",
|
||||
@@ -745,7 +760,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
|
||||
fallback_session_key=gateway.main_session_key,
|
||||
)
|
||||
if not agent_id:
|
||||
raise OpenClawGatewayError("Unable to resolve gateway main agent id")
|
||||
msg = "Unable to resolve gateway main agent id"
|
||||
raise OpenClawGatewayError(msg)
|
||||
|
||||
context = _build_main_context(agent, gateway, auth_token, user)
|
||||
supported = set(await _supported_gateway_files(client_config))
|
||||
@@ -796,7 +812,8 @@ async def cleanup_agent(
|
||||
if not gateway.url:
|
||||
return None
|
||||
if not gateway.workspace_root:
|
||||
raise ValueError("gateway_workspace_root is required")
|
||||
msg = "gateway_workspace_root is required"
|
||||
raise ValueError(msg)
|
||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
|
||||
agent_id = _agent_key(agent)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Helpers for assembling board-group snapshot view models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import case, func
|
||||
@@ -22,48 +23,67 @@ from app.schemas.view_models import (
|
||||
|
||||
_STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3}
|
||||
_PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2}
|
||||
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession)
|
||||
|
||||
|
||||
def _status_weight_expr() -> Any:
|
||||
def _status_weight_expr() -> object:
|
||||
"""Return a SQL expression that sorts task statuses by configured order."""
|
||||
whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()]
|
||||
return case(*whens, else_=99)
|
||||
|
||||
|
||||
def _priority_weight_expr() -> Any:
|
||||
whens = [(col(Task.priority) == key, weight) for key, weight in _PRIORITY_ORDER.items()]
|
||||
def _priority_weight_expr() -> object:
|
||||
"""Return a SQL expression that sorts task priorities by configured order."""
|
||||
whens = [
|
||||
(col(Task.priority) == key, weight)
|
||||
for key, weight in _PRIORITY_ORDER.items()
|
||||
]
|
||||
return case(*whens, else_=99)
|
||||
|
||||
|
||||
async def build_group_snapshot(
|
||||
async def _boards_for_group(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
group: BoardGroup,
|
||||
group_id: UUID,
|
||||
exclude_board_id: UUID | None = None,
|
||||
include_done: bool = False,
|
||||
per_board_task_limit: int = 5,
|
||||
) -> BoardGroupSnapshot:
|
||||
statement = Board.objects.filter_by(board_group_id=group.id).statement
|
||||
) -> list[Board]:
|
||||
"""Return boards belonging to a board group with optional exclusion."""
|
||||
statement = Board.objects.filter_by(board_group_id=group_id).statement
|
||||
if exclude_board_id is not None:
|
||||
statement = statement.where(col(Board.id) != exclude_board_id)
|
||||
boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc())))
|
||||
if not boards:
|
||||
return BoardGroupSnapshot(group=BoardGroupRead.model_validate(group, from_attributes=True))
|
||||
return list(
|
||||
await session.exec(
|
||||
statement.order_by(func.lower(col(Board.name)).asc()),
|
||||
),
|
||||
)
|
||||
|
||||
boards_by_id = {board.id: board for board in boards}
|
||||
board_ids = list(boards_by_id.keys())
|
||||
|
||||
async def _task_counts_by_board(
|
||||
session: AsyncSession,
|
||||
board_ids: list[UUID],
|
||||
) -> dict[UUID, dict[str, int]]:
|
||||
"""Return per-board task counts keyed by task status."""
|
||||
task_counts: dict[UUID, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
for board_id, status_value, total in list(
|
||||
await session.exec(
|
||||
select(col(Task.board_id), col(Task.status), func.count(col(Task.id)))
|
||||
.where(col(Task.board_id).in_(board_ids))
|
||||
.group_by(col(Task.board_id), col(Task.status))
|
||||
)
|
||||
.group_by(col(Task.board_id), col(Task.status)),
|
||||
),
|
||||
):
|
||||
if board_id is None:
|
||||
continue
|
||||
task_counts[board_id][str(status_value)] = int(total or 0)
|
||||
return task_counts
|
||||
|
||||
|
||||
async def _ordered_tasks_for_boards(
|
||||
session: AsyncSession,
|
||||
board_ids: list[UUID],
|
||||
*,
|
||||
include_done: bool,
|
||||
) -> list[Task]:
|
||||
"""Return sorted tasks for boards, optionally excluding completed tasks."""
|
||||
task_statement = select(Task).where(col(Task.board_id).in_(board_ids))
|
||||
if not include_done:
|
||||
task_statement = task_statement.where(col(Task.status) != "done")
|
||||
@@ -74,62 +94,116 @@ async def build_group_snapshot(
|
||||
col(Task.updated_at).desc(),
|
||||
col(Task.created_at).desc(),
|
||||
)
|
||||
tasks = list(await session.exec(task_statement))
|
||||
return list(await session.exec(task_statement))
|
||||
|
||||
assigned_ids = {task.assigned_agent_id for task in tasks if task.assigned_agent_id is not None}
|
||||
agent_name_by_id: dict[UUID, str] = {}
|
||||
if assigned_ids:
|
||||
for agent_id, name in list(
|
||||
|
||||
async def _agent_names(
|
||||
session: AsyncSession,
|
||||
tasks: list[Task],
|
||||
) -> dict[UUID, str]:
|
||||
"""Return agent names keyed by assigned agent ids in the provided tasks."""
|
||||
assigned_ids = {
|
||||
task.assigned_agent_id
|
||||
for task in tasks
|
||||
if task.assigned_agent_id is not None
|
||||
}
|
||||
if not assigned_ids:
|
||||
return {}
|
||||
return dict(
|
||||
list(
|
||||
await session.exec(
|
||||
select(col(Agent.id), col(Agent.name)).where(col(Agent.id).in_(assigned_ids))
|
||||
)
|
||||
):
|
||||
agent_name_by_id[agent_id] = name
|
||||
select(col(Agent.id), col(Agent.name)).where(
|
||||
col(Agent.id).in_(assigned_ids),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _task_summaries_by_board(
|
||||
*,
|
||||
boards_by_id: dict[UUID, Board],
|
||||
tasks: list[Task],
|
||||
agent_name_by_id: dict[UUID, str],
|
||||
per_board_task_limit: int,
|
||||
) -> dict[UUID, list[BoardGroupTaskSummary]]:
|
||||
"""Build limited per-board task summary lists."""
|
||||
tasks_by_board: dict[UUID, list[BoardGroupTaskSummary]] = defaultdict(list)
|
||||
if per_board_task_limit > 0:
|
||||
for task in tasks:
|
||||
if task.board_id is None:
|
||||
continue
|
||||
current = tasks_by_board[task.board_id]
|
||||
if len(current) >= per_board_task_limit:
|
||||
continue
|
||||
board = boards_by_id.get(task.board_id)
|
||||
if board is None:
|
||||
continue
|
||||
current.append(
|
||||
BoardGroupTaskSummary(
|
||||
id=task.id,
|
||||
board_id=task.board_id,
|
||||
board_name=board.name,
|
||||
title=task.title,
|
||||
status=task.status,
|
||||
priority=task.priority,
|
||||
assigned_agent_id=task.assigned_agent_id,
|
||||
assignee=(
|
||||
agent_name_by_id.get(task.assigned_agent_id)
|
||||
if task.assigned_agent_id is not None
|
||||
else None
|
||||
),
|
||||
due_at=task.due_at,
|
||||
in_progress_at=task.in_progress_at,
|
||||
created_at=task.created_at,
|
||||
updated_at=task.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
snapshots: list[BoardGroupBoardSnapshot] = []
|
||||
for board in boards:
|
||||
board_read = BoardRead.model_validate(board, from_attributes=True)
|
||||
counts = dict(task_counts.get(board.id, {}))
|
||||
snapshots.append(
|
||||
BoardGroupBoardSnapshot(
|
||||
board=board_read,
|
||||
task_counts=counts,
|
||||
tasks=tasks_by_board.get(board.id, []),
|
||||
)
|
||||
if per_board_task_limit <= 0:
|
||||
return tasks_by_board
|
||||
for task in tasks:
|
||||
if task.board_id is None:
|
||||
continue
|
||||
current = tasks_by_board[task.board_id]
|
||||
if len(current) >= per_board_task_limit:
|
||||
continue
|
||||
board = boards_by_id.get(task.board_id)
|
||||
if board is None:
|
||||
continue
|
||||
current.append(
|
||||
BoardGroupTaskSummary(
|
||||
id=task.id,
|
||||
board_id=task.board_id,
|
||||
board_name=board.name,
|
||||
title=task.title,
|
||||
status=task.status,
|
||||
priority=task.priority,
|
||||
assigned_agent_id=task.assigned_agent_id,
|
||||
assignee=(
|
||||
agent_name_by_id.get(task.assigned_agent_id)
|
||||
if task.assigned_agent_id is not None
|
||||
else None
|
||||
),
|
||||
due_at=task.due_at,
|
||||
in_progress_at=task.in_progress_at,
|
||||
created_at=task.created_at,
|
||||
updated_at=task.updated_at,
|
||||
),
|
||||
)
|
||||
return tasks_by_board
|
||||
|
||||
|
||||
async def build_group_snapshot(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
group: BoardGroup,
|
||||
exclude_board_id: UUID | None = None,
|
||||
include_done: bool = False,
|
||||
per_board_task_limit: int = 5,
|
||||
) -> BoardGroupSnapshot:
|
||||
"""Build a board-group snapshot with board/task summaries."""
|
||||
boards = await _boards_for_group(
|
||||
session,
|
||||
group_id=group.id,
|
||||
exclude_board_id=exclude_board_id,
|
||||
)
|
||||
if not boards:
|
||||
return BoardGroupSnapshot(
|
||||
group=BoardGroupRead.model_validate(group, from_attributes=True),
|
||||
)
|
||||
boards_by_id = {board.id: board for board in boards}
|
||||
board_ids = list(boards_by_id.keys())
|
||||
task_counts = await _task_counts_by_board(session, board_ids)
|
||||
tasks = await _ordered_tasks_for_boards(
|
||||
session,
|
||||
board_ids,
|
||||
include_done=include_done,
|
||||
)
|
||||
agent_name_by_id = await _agent_names(session, tasks)
|
||||
tasks_by_board = _task_summaries_by_board(
|
||||
boards_by_id=boards_by_id,
|
||||
tasks=tasks,
|
||||
agent_name_by_id=agent_name_by_id,
|
||||
per_board_task_limit=per_board_task_limit,
|
||||
)
|
||||
snapshots = [
|
||||
BoardGroupBoardSnapshot(
|
||||
board=BoardRead.model_validate(board, from_attributes=True),
|
||||
task_counts=dict(task_counts.get(board.id, {})),
|
||||
tasks=tasks_by_board.get(board.id, []),
|
||||
)
|
||||
for board in boards
|
||||
]
|
||||
return BoardGroupSnapshot(
|
||||
group=BoardGroupRead.model_validate(group, from_attributes=True),
|
||||
boards=snapshots,
|
||||
@@ -144,6 +218,7 @@ async def build_board_group_snapshot(
|
||||
include_done: bool = False,
|
||||
per_board_task_limit: int = 5,
|
||||
) -> BoardGroupSnapshot:
|
||||
"""Build a board-group snapshot anchored to a board context."""
|
||||
if not board.board_group_id:
|
||||
return BoardGroupSnapshot(group=None, boards=[])
|
||||
group = await BoardGroup.objects.by_id(board.board_group_id).first(session)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Organization membership and board-access service helpers."""
|
||||
# ruff: noqa: D101, D103
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -38,19 +37,24 @@ ROLE_RANK = {"member": 0, "admin": 1, "owner": 2}
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OrganizationContext:
|
||||
"""Resolved organization and membership for the active user."""
|
||||
|
||||
organization: Organization
|
||||
member: OrganizationMember
|
||||
|
||||
|
||||
def is_org_admin(member: OrganizationMember) -> bool:
|
||||
"""Return whether a member has admin-level organization privileges."""
|
||||
return member.role in ADMIN_ROLES
|
||||
|
||||
|
||||
async def get_default_org(session: AsyncSession) -> Organization | None:
|
||||
"""Return the default personal organization if it exists."""
|
||||
return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session)
|
||||
|
||||
|
||||
async def ensure_default_org(session: AsyncSession) -> Organization:
|
||||
"""Ensure and return the default personal organization."""
|
||||
org = await get_default_org(session)
|
||||
if org is not None:
|
||||
return org
|
||||
@@ -67,6 +71,7 @@ async def get_member(
|
||||
user_id: UUID,
|
||||
organization_id: UUID,
|
||||
) -> OrganizationMember | None:
|
||||
"""Fetch a membership by user id and organization id."""
|
||||
return await OrganizationMember.objects.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
@@ -76,6 +81,7 @@ async def get_member(
|
||||
async def get_first_membership(
|
||||
session: AsyncSession, user_id: UUID,
|
||||
) -> OrganizationMember | None:
|
||||
"""Return the oldest membership for a user, if any."""
|
||||
return (
|
||||
await OrganizationMember.objects.filter_by(user_id=user_id)
|
||||
.order_by(col(OrganizationMember.created_at).asc())
|
||||
@@ -89,6 +95,7 @@ async def set_active_organization(
|
||||
user: User,
|
||||
organization_id: UUID,
|
||||
) -> OrganizationMember:
|
||||
"""Set a user's active organization and return the membership."""
|
||||
member = await get_member(session, user_id=user.id, organization_id=organization_id)
|
||||
if member is None:
|
||||
raise HTTPException(
|
||||
@@ -105,6 +112,7 @@ async def get_active_membership(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
) -> OrganizationMember | None:
|
||||
"""Resolve and normalize the user's currently active membership."""
|
||||
db_user = await User.objects.by_id(user.id).first(session)
|
||||
if db_user is None:
|
||||
db_user = user
|
||||
@@ -151,6 +159,7 @@ async def accept_invite(
|
||||
invite: OrganizationInvite,
|
||||
user: User,
|
||||
) -> OrganizationMember:
|
||||
"""Accept an invite and create membership plus scoped board access rows."""
|
||||
now = utcnow()
|
||||
member = OrganizationMember(
|
||||
organization_id=invite.organization_id,
|
||||
@@ -200,6 +209,7 @@ async def accept_invite(
|
||||
async def ensure_member_for_user(
|
||||
session: AsyncSession, user: User,
|
||||
) -> OrganizationMember:
|
||||
"""Ensure a user has some membership, creating one if necessary."""
|
||||
existing = await get_active_membership(session, user)
|
||||
if existing is not None:
|
||||
return existing
|
||||
@@ -237,10 +247,12 @@ async def ensure_member_for_user(
|
||||
|
||||
|
||||
def member_all_boards_read(member: OrganizationMember) -> bool:
|
||||
"""Return whether the member has organization-wide read access."""
|
||||
return member.all_boards_read or member.all_boards_write
|
||||
|
||||
|
||||
def member_all_boards_write(member: OrganizationMember) -> bool:
|
||||
"""Return whether the member has organization-wide write access."""
|
||||
return member.all_boards_write
|
||||
|
||||
|
||||
@@ -251,6 +263,7 @@ async def has_board_access(
|
||||
board: Board,
|
||||
write: bool,
|
||||
) -> bool:
|
||||
"""Return whether a member has board access for the requested mode."""
|
||||
if member.organization_id != board.organization_id:
|
||||
return False
|
||||
if write:
|
||||
@@ -276,6 +289,7 @@ async def require_board_access(
|
||||
board: Board,
|
||||
write: bool,
|
||||
) -> OrganizationMember:
|
||||
"""Require board access for a user and return matching membership."""
|
||||
member = await get_member(
|
||||
session, user_id=user.id, organization_id=board.organization_id,
|
||||
)
|
||||
@@ -293,6 +307,7 @@ async def require_board_access(
|
||||
def board_access_filter(
|
||||
member: OrganizationMember, *, write: bool,
|
||||
) -> ColumnElement[bool]:
|
||||
"""Build a SQL filter expression for boards visible to a member."""
|
||||
if write and member_all_boards_write(member):
|
||||
return col(Board.organization_id) == member.organization_id
|
||||
if not write and member_all_boards_read(member):
|
||||
@@ -320,6 +335,7 @@ async def list_accessible_board_ids(
|
||||
member: OrganizationMember,
|
||||
write: bool,
|
||||
) -> list[UUID]:
|
||||
"""List board ids accessible to a member for read or write mode."""
|
||||
if (write and member_all_boards_write(member)) or (
|
||||
not write and member_all_boards_read(member)
|
||||
):
|
||||
@@ -354,6 +370,7 @@ async def apply_member_access_update(
|
||||
member: OrganizationMember,
|
||||
update: OrganizationMemberAccessUpdate,
|
||||
) -> None:
|
||||
"""Replace explicit member board-access rows from an access update."""
|
||||
now = utcnow()
|
||||
member.all_boards_read = update.all_boards_read
|
||||
member.all_boards_write = update.all_boards_write
|
||||
@@ -390,6 +407,7 @@ async def apply_invite_board_access(
|
||||
invite: OrganizationInvite,
|
||||
entries: Iterable[OrganizationBoardAccessSpec],
|
||||
) -> None:
|
||||
"""Replace explicit invite board-access rows for an invite."""
|
||||
await crud.delete_where(
|
||||
session,
|
||||
OrganizationInviteBoardAccess,
|
||||
@@ -414,10 +432,12 @@ async def apply_invite_board_access(
|
||||
|
||||
|
||||
def normalize_invited_email(email: str) -> str:
|
||||
"""Normalize an invited email address for storage/comparison."""
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def normalize_role(role: str) -> str:
|
||||
"""Normalize a role string and default empty values to `member`."""
|
||||
return role.strip().lower() or "member"
|
||||
|
||||
|
||||
@@ -433,6 +453,7 @@ async def apply_invite_to_member(
|
||||
member: OrganizationMember,
|
||||
invite: OrganizationInvite,
|
||||
) -> None:
|
||||
"""Apply invite role/access grants onto an existing organization member."""
|
||||
now = utcnow()
|
||||
member_changed = False
|
||||
invite_role = normalize_role(invite.role or "member")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Task-dependency helpers for validation, querying, and replacement."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
@@ -14,6 +16,7 @@ from app.models.task_dependencies import TaskDependency
|
||||
from app.models.tasks import Task
|
||||
|
||||
DONE_STATUS: Final[str] = "done"
|
||||
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession, Mapping, Sequence)
|
||||
|
||||
|
||||
def _dedupe_uuid_list(values: Sequence[UUID]) -> list[UUID]:
|
||||
@@ -34,6 +37,7 @@ async def dependency_ids_by_task_id(
|
||||
board_id: UUID,
|
||||
task_ids: Sequence[UUID],
|
||||
) -> dict[UUID, list[UUID]]:
|
||||
"""Return dependency ids keyed by task id for tasks on a board."""
|
||||
if not task_ids:
|
||||
return {}
|
||||
rows = list(
|
||||
@@ -41,8 +45,8 @@ async def dependency_ids_by_task_id(
|
||||
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id))
|
||||
.where(col(TaskDependency.board_id) == board_id)
|
||||
.where(col(TaskDependency.task_id).in_(task_ids))
|
||||
.order_by(col(TaskDependency.created_at).asc())
|
||||
)
|
||||
.order_by(col(TaskDependency.created_at).asc()),
|
||||
),
|
||||
)
|
||||
mapping: dict[UUID, list[UUID]] = defaultdict(list)
|
||||
for task_id, depends_on_task_id in rows:
|
||||
@@ -56,16 +60,17 @@ async def dependency_status_by_id(
|
||||
board_id: UUID,
|
||||
dependency_ids: Sequence[UUID],
|
||||
) -> dict[UUID, str]:
|
||||
"""Return dependency status values keyed by dependency task id."""
|
||||
if not dependency_ids:
|
||||
return {}
|
||||
rows = list(
|
||||
await session.exec(
|
||||
select(col(Task.id), col(Task.status))
|
||||
.where(col(Task.board_id) == board_id)
|
||||
.where(col(Task.id).in_(dependency_ids))
|
||||
)
|
||||
.where(col(Task.id).in_(dependency_ids)),
|
||||
),
|
||||
)
|
||||
return {task_id: status_value for task_id, status_value in rows}
|
||||
return dict(rows)
|
||||
|
||||
|
||||
def blocked_by_dependency_ids(
|
||||
@@ -73,11 +78,12 @@ def blocked_by_dependency_ids(
|
||||
dependency_ids: Sequence[UUID],
|
||||
status_by_id: Mapping[UUID, str],
|
||||
) -> list[UUID]:
|
||||
blocked: list[UUID] = []
|
||||
for dep_id in dependency_ids:
|
||||
if status_by_id.get(dep_id) != DONE_STATUS:
|
||||
blocked.append(dep_id)
|
||||
return blocked
|
||||
"""Return dependency ids that are not yet in the done status."""
|
||||
return [
|
||||
dep_id
|
||||
for dep_id in dependency_ids
|
||||
if status_by_id.get(dep_id) != DONE_STATUS
|
||||
]
|
||||
|
||||
|
||||
async def blocked_by_for_task(
|
||||
@@ -87,6 +93,7 @@ async def blocked_by_for_task(
|
||||
task_id: UUID,
|
||||
dependency_ids: Sequence[UUID] | None = None,
|
||||
) -> list[UUID]:
|
||||
"""Return unresolved dependency ids for the provided task."""
|
||||
dep_ids = list(dependency_ids or [])
|
||||
if dependency_ids is None:
|
||||
deps_map = await dependency_ids_by_task_id(
|
||||
@@ -97,11 +104,16 @@ async def blocked_by_for_task(
|
||||
dep_ids = deps_map.get(task_id, [])
|
||||
if not dep_ids:
|
||||
return []
|
||||
status_by_id = await dependency_status_by_id(session, board_id=board_id, dependency_ids=dep_ids)
|
||||
status_by_id = await dependency_status_by_id(
|
||||
session,
|
||||
board_id=board_id,
|
||||
dependency_ids=dep_ids,
|
||||
)
|
||||
return blocked_by_dependency_ids(dependency_ids=dep_ids, status_by_id=status_by_id)
|
||||
|
||||
|
||||
def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
|
||||
"""Detect cycles in a directed dependency graph."""
|
||||
visited: set[UUID] = set()
|
||||
in_stack: set[UUID] = set()
|
||||
|
||||
@@ -118,10 +130,7 @@ def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
|
||||
in_stack.remove(node)
|
||||
return False
|
||||
|
||||
for node in nodes:
|
||||
if dfs(node):
|
||||
return True
|
||||
return False
|
||||
return any(dfs(node) for node in nodes)
|
||||
|
||||
|
||||
async def validate_dependency_update(
|
||||
@@ -131,6 +140,7 @@ async def validate_dependency_update(
|
||||
task_id: UUID,
|
||||
depends_on_task_ids: Sequence[UUID],
|
||||
) -> list[UUID]:
|
||||
"""Validate a dependency update and return normalized dependency ids."""
|
||||
normalized = _dedupe_uuid_list(depends_on_task_ids)
|
||||
if task_id in normalized:
|
||||
raise HTTPException(
|
||||
@@ -145,8 +155,8 @@ async def validate_dependency_update(
|
||||
await session.exec(
|
||||
select(col(Task.id))
|
||||
.where(col(Task.board_id) == board_id)
|
||||
.where(col(Task.id).in_(normalized))
|
||||
)
|
||||
.where(col(Task.id).in_(normalized)),
|
||||
),
|
||||
)
|
||||
missing = [dep_id for dep_id in normalized if dep_id not in existing_ids]
|
||||
if missing:
|
||||
@@ -159,13 +169,18 @@ async def validate_dependency_update(
|
||||
)
|
||||
|
||||
# Ensure the dependency graph is acyclic after applying the update.
|
||||
task_ids = list(await session.exec(select(col(Task.id)).where(col(Task.board_id) == board_id)))
|
||||
task_ids = list(
|
||||
await session.exec(
|
||||
select(col(Task.id)).where(col(Task.board_id) == board_id),
|
||||
),
|
||||
)
|
||||
rows = list(
|
||||
await session.exec(
|
||||
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id)).where(
|
||||
col(TaskDependency.board_id) == board_id
|
||||
)
|
||||
)
|
||||
select(
|
||||
col(TaskDependency.task_id),
|
||||
col(TaskDependency.depends_on_task_id),
|
||||
).where(col(TaskDependency.board_id) == board_id),
|
||||
),
|
||||
)
|
||||
edges: dict[UUID, set[UUID]] = defaultdict(set)
|
||||
for src, dst in rows:
|
||||
@@ -188,6 +203,7 @@ async def replace_task_dependencies(
|
||||
task_id: UUID,
|
||||
depends_on_task_ids: Sequence[UUID],
|
||||
) -> list[UUID]:
|
||||
"""Replace dependencies for a task and return the normalized dependency ids."""
|
||||
normalized = await validate_dependency_update(
|
||||
session,
|
||||
board_id=board_id,
|
||||
@@ -207,7 +223,7 @@ async def replace_task_dependencies(
|
||||
board_id=board_id,
|
||||
task_id=task_id,
|
||||
depends_on_task_id=dep_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
return normalized
|
||||
|
||||
@@ -218,9 +234,10 @@ async def dependent_task_ids(
|
||||
board_id: UUID,
|
||||
dependency_task_id: UUID,
|
||||
) -> list[UUID]:
|
||||
"""Return task ids that depend on the provided dependency task id."""
|
||||
rows = await session.exec(
|
||||
select(col(TaskDependency.task_id))
|
||||
.where(col(TaskDependency.board_id) == board_id)
|
||||
.where(col(TaskDependency.depends_on_task_id) == dependency_task_id)
|
||||
.where(col(TaskDependency.depends_on_task_id) == dependency_task_id),
|
||||
)
|
||||
return list(rows)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Gateway template synchronization orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -11,7 +14,11 @@ from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token, verify_agent_token
|
||||
from app.core.agent_tokens import (
|
||||
generate_agent_token,
|
||||
hash_agent_token,
|
||||
verify_agent_token,
|
||||
)
|
||||
from app.core.time import utcnow
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call
|
||||
@@ -49,6 +56,31 @@ _TRANSIENT_GATEWAY_ERROR_MARKERS = (
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
_SECURE_RANDOM = random.SystemRandom()
|
||||
_RUNTIME_TYPE_REFERENCES = (Awaitable, Callable, AsyncSession, Gateway, User, UUID)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GatewayTemplateSyncOptions:
|
||||
"""Runtime options controlling gateway template synchronization."""
|
||||
|
||||
user: User | None
|
||||
include_main: bool = True
|
||||
reset_sessions: bool = False
|
||||
rotate_tokens: bool = False
|
||||
force_bootstrap: bool = False
|
||||
board_id: UUID | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SyncContext:
|
||||
"""Shared state passed to sync helper functions."""
|
||||
|
||||
session: AsyncSession
|
||||
gateway: Gateway
|
||||
config: GatewayClientConfig
|
||||
backoff: _GatewayBackoff
|
||||
options: GatewayTemplateSyncOptions
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
@@ -70,7 +102,10 @@ def _is_transient_gateway_error(exc: Exception) -> bool:
|
||||
|
||||
|
||||
def _gateway_timeout_message(exc: OpenClawGatewayError) -> str:
|
||||
return f"Gateway unreachable after 10 minutes (template sync timeout). Last error: {exc}"
|
||||
return (
|
||||
"Gateway unreachable after 10 minutes (template sync timeout). "
|
||||
f"Last error: {exc}"
|
||||
)
|
||||
|
||||
|
||||
class _GatewayBackoff:
|
||||
@@ -91,16 +126,25 @@ class _GatewayBackoff:
|
||||
def reset(self) -> None:
|
||||
self._delay_s = self._base_delay_s
|
||||
|
||||
async def _attempt(
|
||||
self,
|
||||
fn: Callable[[], Awaitable[T]],
|
||||
) -> tuple[T | None, OpenClawGatewayError | None]:
|
||||
try:
|
||||
return await fn(), None
|
||||
except OpenClawGatewayError as exc:
|
||||
return None, exc
|
||||
|
||||
async def run(self, fn: Callable[[], Awaitable[T]]) -> T:
|
||||
# Use per-call deadlines so long-running syncs can still tolerate a later
|
||||
# gateway restart without having an already-expired retry window.
|
||||
deadline_s = asyncio.get_running_loop().time() + self._timeout_s
|
||||
while True:
|
||||
try:
|
||||
value = await fn()
|
||||
except OpenClawGatewayError as exc:
|
||||
value, error = await self._attempt(fn)
|
||||
if error is not None:
|
||||
exc = error
|
||||
if not _is_transient_gateway_error(exc):
|
||||
raise
|
||||
raise exc
|
||||
now = asyncio.get_running_loop().time()
|
||||
remaining = deadline_s - now
|
||||
if remaining <= 0:
|
||||
@@ -108,13 +152,16 @@ class _GatewayBackoff:
|
||||
|
||||
sleep_s = min(self._delay_s, remaining)
|
||||
if self._jitter:
|
||||
sleep_s *= 1.0 + random.uniform(-self._jitter, self._jitter)
|
||||
sleep_s *= 1.0 + _SECURE_RANDOM.uniform(
|
||||
-self._jitter,
|
||||
self._jitter,
|
||||
)
|
||||
sleep_s = max(0.0, min(sleep_s, remaining))
|
||||
await asyncio.sleep(sleep_s)
|
||||
self._delay_s = min(self._delay_s * 2.0, self._max_delay_s)
|
||||
else:
|
||||
self.reset()
|
||||
return value
|
||||
continue
|
||||
self.reset()
|
||||
return value
|
||||
|
||||
|
||||
async def _with_gateway_retry(
|
||||
@@ -138,23 +185,25 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
|
||||
return agent_id or None
|
||||
|
||||
|
||||
def _extract_agent_id(payload: object) -> str | None:
|
||||
def _from_list(items: object) -> str | None:
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
for item in items:
|
||||
if isinstance(item, str) and item.strip():
|
||||
return item.strip()
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
for key in ("id", "agentId", "agent_id"):
|
||||
raw = item.get(key)
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return raw.strip()
|
||||
def _extract_agent_id_from_list(items: object) -> str | None:
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
for item in items:
|
||||
if isinstance(item, str) and item.strip():
|
||||
return item.strip()
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
for key in ("id", "agentId", "agent_id"):
|
||||
raw = item.get(key)
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return raw.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _extract_agent_id(payload: object) -> str | None:
|
||||
"""Extract a default gateway agent id from common list payload shapes."""
|
||||
if isinstance(payload, list):
|
||||
return _from_list(payload)
|
||||
return _extract_agent_id_from_list(payload)
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
|
||||
@@ -162,7 +211,7 @@ def _extract_agent_id(payload: object) -> str | None:
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return raw.strip()
|
||||
for key in ("agents", "items", "list", "data"):
|
||||
agent_id = _from_list(payload.get(key))
|
||||
agent_id = _extract_agent_id_from_list(payload.get(key))
|
||||
if agent_id:
|
||||
return agent_id
|
||||
return None
|
||||
@@ -212,9 +261,6 @@ async def _get_agent_file(
|
||||
if isinstance(payload, str):
|
||||
return payload
|
||||
if isinstance(payload, dict):
|
||||
# Common shapes:
|
||||
# - {"name": "...", "content": "..."}
|
||||
# - {"file": {"name": "...", "content": "..." }}
|
||||
content = payload.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
@@ -291,18 +337,53 @@ async def _paused_board_ids(session: AsyncSession, board_ids: list[UUID]) -> set
|
||||
return paused
|
||||
|
||||
|
||||
async def sync_gateway_templates(
|
||||
session: AsyncSession,
|
||||
def _append_sync_error(
|
||||
result: GatewayTemplatesSyncResult,
|
||||
*,
|
||||
message: str,
|
||||
agent: Agent | None = None,
|
||||
board: Board | None = None,
|
||||
) -> None:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=agent.id if agent else None,
|
||||
agent_name=agent.name if agent else None,
|
||||
board_id=board.id if board else None,
|
||||
message=message,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _rotate_agent_token(session: AsyncSession, agent: Agent) -> str:
|
||||
token = generate_agent_token()
|
||||
agent.agent_token_hash = hash_agent_token(token)
|
||||
agent.updated_at = utcnow()
|
||||
session.add(agent)
|
||||
await session.commit()
|
||||
await session.refresh(agent)
|
||||
return token
|
||||
|
||||
|
||||
async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool:
|
||||
try:
|
||||
async def _do_ping() -> object:
|
||||
return await openclaw_call("agents.list", config=ctx.config)
|
||||
|
||||
await ctx.backoff.run(_do_ping)
|
||||
except (TimeoutError, OpenClawGatewayError) as exc:
|
||||
_append_sync_error(result, message=str(exc))
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _base_result(
|
||||
gateway: Gateway,
|
||||
*,
|
||||
user: User | None,
|
||||
include_main: bool = True,
|
||||
reset_sessions: bool = False,
|
||||
rotate_tokens: bool = False,
|
||||
force_bootstrap: bool = False,
|
||||
board_id: UUID | None = None,
|
||||
include_main: bool,
|
||||
reset_sessions: bool,
|
||||
) -> GatewayTemplatesSyncResult:
|
||||
result = GatewayTemplatesSyncResult(
|
||||
return GatewayTemplatesSyncResult(
|
||||
gateway_id=gateway.id,
|
||||
include_main=include_main,
|
||||
reset_sessions=reset_sessions,
|
||||
@@ -310,45 +391,239 @@ async def sync_gateway_templates(
|
||||
agents_skipped=0,
|
||||
main_updated=False,
|
||||
)
|
||||
|
||||
|
||||
def _boards_by_id(
|
||||
boards: list[Board],
|
||||
*,
|
||||
board_id: UUID | None,
|
||||
) -> dict[UUID, Board] | None:
|
||||
boards_by_id = {board.id: board for board in boards}
|
||||
if board_id is None:
|
||||
return boards_by_id
|
||||
board = boards_by_id.get(board_id)
|
||||
if board is None:
|
||||
return None
|
||||
return {board_id: board}
|
||||
|
||||
|
||||
async def _resolve_agent_auth_token(
|
||||
ctx: _SyncContext,
|
||||
result: GatewayTemplatesSyncResult,
|
||||
agent: Agent,
|
||||
board: Board | None,
|
||||
*,
|
||||
agent_gateway_id: str,
|
||||
) -> tuple[str | None, bool]:
|
||||
try:
|
||||
auth_token = await _get_existing_auth_token(
|
||||
agent_gateway_id=agent_gateway_id,
|
||||
config=ctx.config,
|
||||
backoff=ctx.backoff,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
_append_sync_error(result, agent=agent, board=board, message=str(exc))
|
||||
return None, True
|
||||
|
||||
if not auth_token:
|
||||
if not ctx.options.rotate_tokens:
|
||||
result.agents_skipped += 1
|
||||
_append_sync_error(
|
||||
result,
|
||||
agent=agent,
|
||||
board=board,
|
||||
message=(
|
||||
"Skipping agent: unable to read AUTH_TOKEN from TOOLS.md "
|
||||
"(run with rotate_tokens=true to re-key)."
|
||||
),
|
||||
)
|
||||
return None, False
|
||||
auth_token = await _rotate_agent_token(ctx.session, agent)
|
||||
|
||||
if agent.agent_token_hash and not verify_agent_token(
|
||||
auth_token,
|
||||
agent.agent_token_hash,
|
||||
):
|
||||
if ctx.options.rotate_tokens:
|
||||
auth_token = await _rotate_agent_token(ctx.session, agent)
|
||||
else:
|
||||
_append_sync_error(
|
||||
result,
|
||||
agent=agent,
|
||||
board=board,
|
||||
message=(
|
||||
"Warning: AUTH_TOKEN in TOOLS.md does not match backend "
|
||||
"token hash (agent auth may be broken)."
|
||||
),
|
||||
)
|
||||
return auth_token, False
|
||||
|
||||
|
||||
async def _sync_one_agent(
|
||||
ctx: _SyncContext,
|
||||
result: GatewayTemplatesSyncResult,
|
||||
agent: Agent,
|
||||
board: Board,
|
||||
) -> bool:
|
||||
auth_token, fatal = await _resolve_agent_auth_token(
|
||||
ctx,
|
||||
result,
|
||||
agent,
|
||||
board,
|
||||
agent_gateway_id=_gateway_agent_id(agent),
|
||||
)
|
||||
if fatal:
|
||||
return True
|
||||
if not auth_token:
|
||||
return False
|
||||
try:
|
||||
async def _do_provision() -> None:
|
||||
await provision_agent(
|
||||
agent,
|
||||
board,
|
||||
ctx.gateway,
|
||||
auth_token,
|
||||
ctx.options.user,
|
||||
action="update",
|
||||
force_bootstrap=ctx.options.force_bootstrap,
|
||||
reset_session=ctx.options.reset_sessions,
|
||||
)
|
||||
|
||||
await _with_gateway_retry(_do_provision, backoff=ctx.backoff)
|
||||
result.agents_updated += 1
|
||||
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
|
||||
result.agents_skipped += 1
|
||||
_append_sync_error(result, agent=agent, board=board, message=str(exc))
|
||||
return True
|
||||
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
|
||||
result.agents_skipped += 1
|
||||
_append_sync_error(
|
||||
result,
|
||||
agent=agent,
|
||||
board=board,
|
||||
message=f"Failed to sync templates: {exc}",
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
async def _sync_main_agent(
|
||||
ctx: _SyncContext,
|
||||
result: GatewayTemplatesSyncResult,
|
||||
) -> bool:
|
||||
main_agent = (
|
||||
await Agent.objects.all()
|
||||
.filter(col(Agent.openclaw_session_id) == ctx.gateway.main_session_key)
|
||||
.first(ctx.session)
|
||||
)
|
||||
if main_agent is None:
|
||||
_append_sync_error(
|
||||
result,
|
||||
message=(
|
||||
"Gateway main agent record not found; "
|
||||
"skipping main agent template sync."
|
||||
),
|
||||
)
|
||||
return True
|
||||
try:
|
||||
main_gateway_agent_id = await _gateway_default_agent_id(
|
||||
ctx.config,
|
||||
fallback_session_key=ctx.gateway.main_session_key,
|
||||
backoff=ctx.backoff,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
_append_sync_error(result, agent=main_agent, message=str(exc))
|
||||
return True
|
||||
if not main_gateway_agent_id:
|
||||
_append_sync_error(
|
||||
result,
|
||||
agent=main_agent,
|
||||
message="Unable to resolve gateway default agent id for main agent.",
|
||||
)
|
||||
return True
|
||||
|
||||
token, fatal = await _resolve_agent_auth_token(
|
||||
ctx,
|
||||
result,
|
||||
main_agent,
|
||||
board=None,
|
||||
agent_gateway_id=main_gateway_agent_id,
|
||||
)
|
||||
if fatal:
|
||||
return True
|
||||
if not token:
|
||||
_append_sync_error(
|
||||
result,
|
||||
agent=main_agent,
|
||||
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
|
||||
)
|
||||
return True
|
||||
stop_sync = False
|
||||
try:
|
||||
async def _do_provision_main() -> None:
|
||||
await provision_main_agent(
|
||||
main_agent,
|
||||
ctx.gateway,
|
||||
token,
|
||||
ctx.options.user,
|
||||
action="update",
|
||||
force_bootstrap=ctx.options.force_bootstrap,
|
||||
reset_session=ctx.options.reset_sessions,
|
||||
)
|
||||
|
||||
await _with_gateway_retry(_do_provision_main, backoff=ctx.backoff)
|
||||
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
|
||||
_append_sync_error(result, agent=main_agent, message=str(exc))
|
||||
stop_sync = True
|
||||
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
|
||||
_append_sync_error(
|
||||
result,
|
||||
agent=main_agent,
|
||||
message=f"Failed to sync main agent templates: {exc}",
|
||||
)
|
||||
else:
|
||||
result.main_updated = True
|
||||
return stop_sync
|
||||
|
||||
|
||||
async def sync_gateway_templates(
|
||||
session: AsyncSession,
|
||||
gateway: Gateway,
|
||||
options: GatewayTemplateSyncOptions,
|
||||
) -> GatewayTemplatesSyncResult:
|
||||
"""Synchronize AGENTS/TOOLS/etc templates to gateway-connected agents."""
|
||||
result = _base_result(
|
||||
gateway,
|
||||
include_main=options.include_main,
|
||||
reset_sessions=options.reset_sessions,
|
||||
)
|
||||
if not gateway.url:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(message="Gateway URL is not configured for this gateway.")
|
||||
_append_sync_error(
|
||||
result,
|
||||
message="Gateway URL is not configured for this gateway.",
|
||||
)
|
||||
return result
|
||||
|
||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
backoff = _GatewayBackoff(timeout_s=10 * 60)
|
||||
|
||||
# First, wait for the gateway to be reachable (e.g. while it is restarting).
|
||||
try:
|
||||
|
||||
async def _do_ping() -> object:
|
||||
return await openclaw_call("agents.list", config=client_config)
|
||||
|
||||
await backoff.run(_do_ping)
|
||||
except TimeoutError as exc:
|
||||
result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
|
||||
return result
|
||||
except OpenClawGatewayError as exc:
|
||||
result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
|
||||
ctx = _SyncContext(
|
||||
session=session,
|
||||
gateway=gateway,
|
||||
config=GatewayClientConfig(url=gateway.url, token=gateway.token),
|
||||
backoff=_GatewayBackoff(timeout_s=10 * 60),
|
||||
options=options,
|
||||
)
|
||||
if not await _ping_gateway(ctx, result):
|
||||
return result
|
||||
|
||||
boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session)
|
||||
boards_by_id = {board.id: board for board in boards}
|
||||
if board_id is not None:
|
||||
board = boards_by_id.get(board_id)
|
||||
if board is None:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
board_id=board_id,
|
||||
message="Board does not belong to this gateway.",
|
||||
)
|
||||
)
|
||||
return result
|
||||
boards_by_id = {board_id: board}
|
||||
|
||||
boards_by_id = _boards_by_id(boards, board_id=options.board_id)
|
||||
if boards_by_id is None:
|
||||
_append_sync_error(
|
||||
result,
|
||||
message="Board does not belong to this gateway.",
|
||||
)
|
||||
return result
|
||||
paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys()))
|
||||
|
||||
if boards_by_id:
|
||||
agents = await (
|
||||
Agent.objects.by_field_in("board_id", list(boards_by_id.keys()))
|
||||
@@ -358,251 +633,24 @@ async def sync_gateway_templates(
|
||||
else:
|
||||
agents = []
|
||||
|
||||
stop_sync = False
|
||||
for agent in agents:
|
||||
board = boards_by_id.get(agent.board_id) if agent.board_id is not None else None
|
||||
if board is None:
|
||||
result.agents_skipped += 1
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=agent.id,
|
||||
agent_name=agent.name,
|
||||
board_id=agent.board_id,
|
||||
message="Skipping agent: board not found for agent.",
|
||||
)
|
||||
_append_sync_error(
|
||||
result,
|
||||
agent=agent,
|
||||
message="Skipping agent: board not found for agent.",
|
||||
)
|
||||
continue
|
||||
|
||||
if board.id in paused_board_ids:
|
||||
result.agents_skipped += 1
|
||||
continue
|
||||
stop_sync = await _sync_one_agent(ctx, result, agent, board)
|
||||
if stop_sync:
|
||||
break
|
||||
|
||||
agent_gateway_id = _gateway_agent_id(agent)
|
||||
try:
|
||||
auth_token = await _get_existing_auth_token(
|
||||
agent_gateway_id=agent_gateway_id,
|
||||
config=client_config,
|
||||
backoff=backoff,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=agent.id,
|
||||
agent_name=agent.name,
|
||||
board_id=board.id,
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
if not auth_token:
|
||||
if not rotate_tokens:
|
||||
result.agents_skipped += 1
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=agent.id,
|
||||
agent_name=agent.name,
|
||||
board_id=board.id,
|
||||
message="Skipping agent: unable to read AUTH_TOKEN from TOOLS.md (run with rotate_tokens=true to re-key).",
|
||||
)
|
||||
)
|
||||
continue
|
||||
raw_token = generate_agent_token()
|
||||
agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
agent.updated_at = utcnow()
|
||||
session.add(agent)
|
||||
await session.commit()
|
||||
await session.refresh(agent)
|
||||
auth_token = raw_token
|
||||
|
||||
if agent.agent_token_hash and not verify_agent_token(auth_token, agent.agent_token_hash):
|
||||
# Do not block template sync on token drift; optionally re-key.
|
||||
if rotate_tokens:
|
||||
raw_token = generate_agent_token()
|
||||
agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
agent.updated_at = utcnow()
|
||||
session.add(agent)
|
||||
await session.commit()
|
||||
await session.refresh(agent)
|
||||
auth_token = raw_token
|
||||
else:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=agent.id,
|
||||
agent_name=agent.name,
|
||||
board_id=board.id,
|
||||
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (agent auth may be broken).",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
agent_item: Agent = agent
|
||||
board_item: Board = board
|
||||
auth_token_value: str = auth_token
|
||||
|
||||
async def _do_provision(
|
||||
agent_item: Agent = agent_item,
|
||||
board_item: Board = board_item,
|
||||
auth_token_value: str = auth_token_value,
|
||||
) -> None:
|
||||
await provision_agent(
|
||||
agent_item,
|
||||
board_item,
|
||||
gateway,
|
||||
auth_token_value,
|
||||
user,
|
||||
action="update",
|
||||
force_bootstrap=force_bootstrap,
|
||||
reset_session=reset_sessions,
|
||||
)
|
||||
|
||||
await _with_gateway_retry(_do_provision, backoff=backoff)
|
||||
result.agents_updated += 1
|
||||
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
|
||||
result.agents_skipped += 1
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=agent.id,
|
||||
agent_name=agent.name,
|
||||
board_id=board.id,
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
return result
|
||||
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
|
||||
result.agents_skipped += 1
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=agent.id,
|
||||
agent_name=agent.name,
|
||||
board_id=board.id,
|
||||
message=f"Failed to sync templates: {exc}",
|
||||
)
|
||||
)
|
||||
|
||||
if include_main:
|
||||
main_agent = (
|
||||
await Agent.objects.all()
|
||||
.filter(col(Agent.openclaw_session_id) == gateway.main_session_key)
|
||||
.first(session)
|
||||
)
|
||||
if main_agent is None:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
message="Gateway main agent record not found; skipping main agent template sync.",
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
main_gateway_agent_id = await _gateway_default_agent_id(
|
||||
client_config,
|
||||
fallback_session_key=gateway.main_session_key,
|
||||
backoff=backoff,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
return result
|
||||
if not main_gateway_agent_id:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message="Unable to resolve gateway default agent id for main agent.",
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
main_token = await _get_existing_auth_token(
|
||||
agent_gateway_id=main_gateway_agent_id,
|
||||
config=client_config,
|
||||
backoff=backoff,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
return result
|
||||
if not main_token:
|
||||
if rotate_tokens:
|
||||
raw_token = generate_agent_token()
|
||||
main_agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
main_agent.updated_at = utcnow()
|
||||
session.add(main_agent)
|
||||
await session.commit()
|
||||
await session.refresh(main_agent)
|
||||
main_token = raw_token
|
||||
else:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
if main_agent.agent_token_hash and not verify_agent_token(
|
||||
main_token, main_agent.agent_token_hash
|
||||
):
|
||||
if rotate_tokens:
|
||||
raw_token = generate_agent_token()
|
||||
main_agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
main_agent.updated_at = utcnow()
|
||||
session.add(main_agent)
|
||||
await session.commit()
|
||||
await session.refresh(main_agent)
|
||||
main_token = raw_token
|
||||
else:
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
async def _do_provision_main() -> None:
|
||||
await provision_main_agent(
|
||||
main_agent,
|
||||
gateway,
|
||||
main_token,
|
||||
user,
|
||||
action="update",
|
||||
force_bootstrap=force_bootstrap,
|
||||
reset_session=reset_sessions,
|
||||
)
|
||||
|
||||
await _with_gateway_retry(_do_provision_main, backoff=backoff)
|
||||
result.main_updated = True
|
||||
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
return result
|
||||
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
|
||||
result.errors.append(
|
||||
GatewayTemplatesSyncError(
|
||||
agent_id=main_agent.id,
|
||||
agent_name=main_agent.name,
|
||||
message=f"Failed to sync main agent templates: {exc}",
|
||||
)
|
||||
)
|
||||
|
||||
if not stop_sync and options.include_main:
|
||||
await _sync_main_agent(ctx, result)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user