refactor: enhance docstrings for clarity and consistency across multiple files

This commit is contained in:
Abhimanyu Saharan
2026-02-09 16:23:41 +05:30
parent 7ca1899d9f
commit 7706943209
28 changed files with 1829 additions and 932 deletions

View File

@@ -1,5 +1,4 @@
"""Gateway-facing agent provisioning and cleanup helpers."""
# ruff: noqa: EM101, TRY003
from __future__ import annotations
@@ -176,7 +175,8 @@ def _heartbeat_template_name(agent: Agent) -> str:
def _workspace_path(agent: Agent, workspace_root: str) -> str:
if not workspace_root:
raise ValueError("gateway_workspace_root is required")
msg = "gateway_workspace_root is required"
raise ValueError(msg)
root = workspace_root.rstrip("/")
# Use agent key derived from session key when possible. This prevents collisions for
# lead agents (session key includes board id) even if multiple boards share the same
@@ -227,9 +227,11 @@ def _build_context(
user: User | None,
) -> dict[str, str]:
if not gateway.workspace_root:
raise ValueError("gateway_workspace_root is required")
msg = "gateway_workspace_root is required"
raise ValueError(msg)
if not gateway.main_session_key:
raise ValueError("gateway_main_session_key is required")
msg = "gateway_main_session_key is required"
raise ValueError(msg)
agent_id = str(agent.id)
workspace_root = gateway.workspace_root
workspace_path = _workspace_path(agent, workspace_root)
@@ -485,15 +487,18 @@ async def _patch_gateway_agent_list(
) -> None:
cfg = await openclaw_call("config.get", config=config)
if not isinstance(cfg, dict):
raise OpenClawGatewayError("config.get returned invalid payload")
msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
base_hash = cfg.get("hash")
data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict):
raise OpenClawGatewayError("config.get returned invalid config")
msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
agents = data.get("agents") or {}
lst = agents.get("list") or []
if not isinstance(lst, list):
raise OpenClawGatewayError("config agents.list is not a list")
msg = "config agents.list is not a list"
raise OpenClawGatewayError(msg)
updated = False
new_list: list[dict[str, Any]] = []
@@ -528,19 +533,23 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
Each entry is (agent_id, workspace_path, heartbeat_dict).
"""
if not gateway.url:
raise OpenClawGatewayError("Gateway url is required")
msg = "Gateway url is required"
raise OpenClawGatewayError(msg)
config = GatewayClientConfig(url=gateway.url, token=gateway.token)
cfg = await openclaw_call("config.get", config=config)
if not isinstance(cfg, dict):
raise OpenClawGatewayError("config.get returned invalid payload")
msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
base_hash = cfg.get("hash")
data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict):
raise OpenClawGatewayError("config.get returned invalid config")
msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
agents_section = data.get("agents") or {}
lst = agents_section.get("list") or []
if not isinstance(lst, list):
raise OpenClawGatewayError("config agents.list is not a list")
msg = "config agents.list is not a list"
raise OpenClawGatewayError(msg)
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
agent_id: (workspace_path, heartbeat)
@@ -581,7 +590,8 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
async def sync_gateway_agent_heartbeats(gateway: Gateway, agents: list[Agent]) -> None:
"""Sync current Agent.heartbeat_config values to the gateway config."""
if not gateway.workspace_root:
raise OpenClawGatewayError("gateway workspace_root is required")
msg = "gateway workspace_root is required"
raise OpenClawGatewayError(msg)
entries: list[tuple[str, str, dict[str, Any]]] = []
for agent in agents:
agent_id = _agent_key(agent)
@@ -599,15 +609,18 @@ async def _remove_gateway_agent_list(
) -> None:
cfg = await openclaw_call("config.get", config=config)
if not isinstance(cfg, dict):
raise OpenClawGatewayError("config.get returned invalid payload")
msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
base_hash = cfg.get("hash")
data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict):
raise OpenClawGatewayError("config.get returned invalid config")
msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
agents = data.get("agents") or {}
lst = agents.get("list") or []
if not isinstance(lst, list):
raise OpenClawGatewayError("config agents.list is not a list")
msg = "config agents.list is not a list"
raise OpenClawGatewayError(msg)
new_list = [
entry
@@ -658,7 +671,8 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913
if not gateway.url:
return
if not gateway.workspace_root:
raise ValueError("gateway_workspace_root is required")
msg = "gateway_workspace_root is required"
raise ValueError(msg)
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
session_key = _session_key(agent)
await ensure_session(session_key, config=client_config, label=agent.name)
@@ -734,7 +748,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
if not gateway.url:
return
if not gateway.main_session_key:
raise ValueError("gateway main_session_key is required")
msg = "gateway main_session_key is required"
raise ValueError(msg)
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session(
gateway.main_session_key, config=client_config, label="Main Agent",
@@ -745,7 +760,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
fallback_session_key=gateway.main_session_key,
)
if not agent_id:
raise OpenClawGatewayError("Unable to resolve gateway main agent id")
msg = "Unable to resolve gateway main agent id"
raise OpenClawGatewayError(msg)
context = _build_main_context(agent, gateway, auth_token, user)
supported = set(await _supported_gateway_files(client_config))
@@ -796,7 +812,8 @@ async def cleanup_agent(
if not gateway.url:
return None
if not gateway.workspace_root:
raise ValueError("gateway_workspace_root is required")
msg = "gateway_workspace_root is required"
raise ValueError(msg)
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
agent_id = _agent_key(agent)

View File

@@ -1,7 +1,8 @@
"""Helpers for assembling board-group snapshot view models."""
from __future__ import annotations
from collections import defaultdict
from typing import Any
from uuid import UUID
from sqlalchemy import case, func
@@ -22,48 +23,67 @@ from app.schemas.view_models import (
_STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3}
_PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2}
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession)
def _status_weight_expr() -> Any:
def _status_weight_expr() -> object:
"""Return a SQL expression that sorts task statuses by configured order."""
whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()]
return case(*whens, else_=99)
def _priority_weight_expr() -> Any:
whens = [(col(Task.priority) == key, weight) for key, weight in _PRIORITY_ORDER.items()]
def _priority_weight_expr() -> object:
"""Return a SQL expression that sorts task priorities by configured order."""
whens = [
(col(Task.priority) == key, weight)
for key, weight in _PRIORITY_ORDER.items()
]
return case(*whens, else_=99)
async def build_group_snapshot(
async def _boards_for_group(
session: AsyncSession,
*,
group: BoardGroup,
group_id: UUID,
exclude_board_id: UUID | None = None,
include_done: bool = False,
per_board_task_limit: int = 5,
) -> BoardGroupSnapshot:
statement = Board.objects.filter_by(board_group_id=group.id).statement
) -> list[Board]:
"""Return boards belonging to a board group with optional exclusion."""
statement = Board.objects.filter_by(board_group_id=group_id).statement
if exclude_board_id is not None:
statement = statement.where(col(Board.id) != exclude_board_id)
boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc())))
if not boards:
return BoardGroupSnapshot(group=BoardGroupRead.model_validate(group, from_attributes=True))
return list(
await session.exec(
statement.order_by(func.lower(col(Board.name)).asc()),
),
)
boards_by_id = {board.id: board for board in boards}
board_ids = list(boards_by_id.keys())
async def _task_counts_by_board(
session: AsyncSession,
board_ids: list[UUID],
) -> dict[UUID, dict[str, int]]:
"""Return per-board task counts keyed by task status."""
task_counts: dict[UUID, dict[str, int]] = defaultdict(lambda: defaultdict(int))
for board_id, status_value, total in list(
await session.exec(
select(col(Task.board_id), col(Task.status), func.count(col(Task.id)))
.where(col(Task.board_id).in_(board_ids))
.group_by(col(Task.board_id), col(Task.status))
)
.group_by(col(Task.board_id), col(Task.status)),
),
):
if board_id is None:
continue
task_counts[board_id][str(status_value)] = int(total or 0)
return task_counts
async def _ordered_tasks_for_boards(
session: AsyncSession,
board_ids: list[UUID],
*,
include_done: bool,
) -> list[Task]:
"""Return sorted tasks for boards, optionally excluding completed tasks."""
task_statement = select(Task).where(col(Task.board_id).in_(board_ids))
if not include_done:
task_statement = task_statement.where(col(Task.status) != "done")
@@ -74,62 +94,116 @@ async def build_group_snapshot(
col(Task.updated_at).desc(),
col(Task.created_at).desc(),
)
tasks = list(await session.exec(task_statement))
return list(await session.exec(task_statement))
assigned_ids = {task.assigned_agent_id for task in tasks if task.assigned_agent_id is not None}
agent_name_by_id: dict[UUID, str] = {}
if assigned_ids:
for agent_id, name in list(
async def _agent_names(
session: AsyncSession,
tasks: list[Task],
) -> dict[UUID, str]:
"""Return agent names keyed by assigned agent ids in the provided tasks."""
assigned_ids = {
task.assigned_agent_id
for task in tasks
if task.assigned_agent_id is not None
}
if not assigned_ids:
return {}
return dict(
list(
await session.exec(
select(col(Agent.id), col(Agent.name)).where(col(Agent.id).in_(assigned_ids))
)
):
agent_name_by_id[agent_id] = name
select(col(Agent.id), col(Agent.name)).where(
col(Agent.id).in_(assigned_ids),
),
),
),
)
def _task_summaries_by_board(
*,
boards_by_id: dict[UUID, Board],
tasks: list[Task],
agent_name_by_id: dict[UUID, str],
per_board_task_limit: int,
) -> dict[UUID, list[BoardGroupTaskSummary]]:
"""Build limited per-board task summary lists."""
tasks_by_board: dict[UUID, list[BoardGroupTaskSummary]] = defaultdict(list)
if per_board_task_limit > 0:
for task in tasks:
if task.board_id is None:
continue
current = tasks_by_board[task.board_id]
if len(current) >= per_board_task_limit:
continue
board = boards_by_id.get(task.board_id)
if board is None:
continue
current.append(
BoardGroupTaskSummary(
id=task.id,
board_id=task.board_id,
board_name=board.name,
title=task.title,
status=task.status,
priority=task.priority,
assigned_agent_id=task.assigned_agent_id,
assignee=(
agent_name_by_id.get(task.assigned_agent_id)
if task.assigned_agent_id is not None
else None
),
due_at=task.due_at,
in_progress_at=task.in_progress_at,
created_at=task.created_at,
updated_at=task.updated_at,
)
)
snapshots: list[BoardGroupBoardSnapshot] = []
for board in boards:
board_read = BoardRead.model_validate(board, from_attributes=True)
counts = dict(task_counts.get(board.id, {}))
snapshots.append(
BoardGroupBoardSnapshot(
board=board_read,
task_counts=counts,
tasks=tasks_by_board.get(board.id, []),
)
if per_board_task_limit <= 0:
return tasks_by_board
for task in tasks:
if task.board_id is None:
continue
current = tasks_by_board[task.board_id]
if len(current) >= per_board_task_limit:
continue
board = boards_by_id.get(task.board_id)
if board is None:
continue
current.append(
BoardGroupTaskSummary(
id=task.id,
board_id=task.board_id,
board_name=board.name,
title=task.title,
status=task.status,
priority=task.priority,
assigned_agent_id=task.assigned_agent_id,
assignee=(
agent_name_by_id.get(task.assigned_agent_id)
if task.assigned_agent_id is not None
else None
),
due_at=task.due_at,
in_progress_at=task.in_progress_at,
created_at=task.created_at,
updated_at=task.updated_at,
),
)
return tasks_by_board
async def build_group_snapshot(
session: AsyncSession,
*,
group: BoardGroup,
exclude_board_id: UUID | None = None,
include_done: bool = False,
per_board_task_limit: int = 5,
) -> BoardGroupSnapshot:
"""Build a board-group snapshot with board/task summaries."""
boards = await _boards_for_group(
session,
group_id=group.id,
exclude_board_id=exclude_board_id,
)
if not boards:
return BoardGroupSnapshot(
group=BoardGroupRead.model_validate(group, from_attributes=True),
)
boards_by_id = {board.id: board for board in boards}
board_ids = list(boards_by_id.keys())
task_counts = await _task_counts_by_board(session, board_ids)
tasks = await _ordered_tasks_for_boards(
session,
board_ids,
include_done=include_done,
)
agent_name_by_id = await _agent_names(session, tasks)
tasks_by_board = _task_summaries_by_board(
boards_by_id=boards_by_id,
tasks=tasks,
agent_name_by_id=agent_name_by_id,
per_board_task_limit=per_board_task_limit,
)
snapshots = [
BoardGroupBoardSnapshot(
board=BoardRead.model_validate(board, from_attributes=True),
task_counts=dict(task_counts.get(board.id, {})),
tasks=tasks_by_board.get(board.id, []),
)
for board in boards
]
return BoardGroupSnapshot(
group=BoardGroupRead.model_validate(group, from_attributes=True),
boards=snapshots,
@@ -144,6 +218,7 @@ async def build_board_group_snapshot(
include_done: bool = False,
per_board_task_limit: int = 5,
) -> BoardGroupSnapshot:
"""Build a board-group snapshot anchored to a board context."""
if not board.board_group_id:
return BoardGroupSnapshot(group=None, boards=[])
group = await BoardGroup.objects.by_id(board.board_group_id).first(session)

View File

@@ -1,5 +1,4 @@
"""Organization membership and board-access service helpers."""
# ruff: noqa: D101, D103
from __future__ import annotations
@@ -38,19 +37,24 @@ ROLE_RANK = {"member": 0, "admin": 1, "owner": 2}
@dataclass(frozen=True)
class OrganizationContext:
"""Resolved organization and membership for the active user."""
organization: Organization
member: OrganizationMember
def is_org_admin(member: OrganizationMember) -> bool:
"""Return whether a member has admin-level organization privileges."""
return member.role in ADMIN_ROLES
async def get_default_org(session: AsyncSession) -> Organization | None:
"""Return the default personal organization if it exists."""
return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session)
async def ensure_default_org(session: AsyncSession) -> Organization:
"""Ensure and return the default personal organization."""
org = await get_default_org(session)
if org is not None:
return org
@@ -67,6 +71,7 @@ async def get_member(
user_id: UUID,
organization_id: UUID,
) -> OrganizationMember | None:
"""Fetch a membership by user id and organization id."""
return await OrganizationMember.objects.filter_by(
user_id=user_id,
organization_id=organization_id,
@@ -76,6 +81,7 @@ async def get_member(
async def get_first_membership(
session: AsyncSession, user_id: UUID,
) -> OrganizationMember | None:
"""Return the oldest membership for a user, if any."""
return (
await OrganizationMember.objects.filter_by(user_id=user_id)
.order_by(col(OrganizationMember.created_at).asc())
@@ -89,6 +95,7 @@ async def set_active_organization(
user: User,
organization_id: UUID,
) -> OrganizationMember:
"""Set a user's active organization and return the membership."""
member = await get_member(session, user_id=user.id, organization_id=organization_id)
if member is None:
raise HTTPException(
@@ -105,6 +112,7 @@ async def get_active_membership(
session: AsyncSession,
user: User,
) -> OrganizationMember | None:
"""Resolve and normalize the user's currently active membership."""
db_user = await User.objects.by_id(user.id).first(session)
if db_user is None:
db_user = user
@@ -151,6 +159,7 @@ async def accept_invite(
invite: OrganizationInvite,
user: User,
) -> OrganizationMember:
"""Accept an invite and create membership plus scoped board access rows."""
now = utcnow()
member = OrganizationMember(
organization_id=invite.organization_id,
@@ -200,6 +209,7 @@ async def accept_invite(
async def ensure_member_for_user(
session: AsyncSession, user: User,
) -> OrganizationMember:
"""Ensure a user has some membership, creating one if necessary."""
existing = await get_active_membership(session, user)
if existing is not None:
return existing
@@ -237,10 +247,12 @@ async def ensure_member_for_user(
def member_all_boards_read(member: OrganizationMember) -> bool:
"""Return whether the member has organization-wide read access."""
return member.all_boards_read or member.all_boards_write
def member_all_boards_write(member: OrganizationMember) -> bool:
"""Return whether the member has organization-wide write access."""
return member.all_boards_write
@@ -251,6 +263,7 @@ async def has_board_access(
board: Board,
write: bool,
) -> bool:
"""Return whether a member has board access for the requested mode."""
if member.organization_id != board.organization_id:
return False
if write:
@@ -276,6 +289,7 @@ async def require_board_access(
board: Board,
write: bool,
) -> OrganizationMember:
"""Require board access for a user and return matching membership."""
member = await get_member(
session, user_id=user.id, organization_id=board.organization_id,
)
@@ -293,6 +307,7 @@ async def require_board_access(
def board_access_filter(
member: OrganizationMember, *, write: bool,
) -> ColumnElement[bool]:
"""Build a SQL filter expression for boards visible to a member."""
if write and member_all_boards_write(member):
return col(Board.organization_id) == member.organization_id
if not write and member_all_boards_read(member):
@@ -320,6 +335,7 @@ async def list_accessible_board_ids(
member: OrganizationMember,
write: bool,
) -> list[UUID]:
"""List board ids accessible to a member for read or write mode."""
if (write and member_all_boards_write(member)) or (
not write and member_all_boards_read(member)
):
@@ -354,6 +370,7 @@ async def apply_member_access_update(
member: OrganizationMember,
update: OrganizationMemberAccessUpdate,
) -> None:
"""Replace explicit member board-access rows from an access update."""
now = utcnow()
member.all_boards_read = update.all_boards_read
member.all_boards_write = update.all_boards_write
@@ -390,6 +407,7 @@ async def apply_invite_board_access(
invite: OrganizationInvite,
entries: Iterable[OrganizationBoardAccessSpec],
) -> None:
"""Replace explicit invite board-access rows for an invite."""
await crud.delete_where(
session,
OrganizationInviteBoardAccess,
@@ -414,10 +432,12 @@ async def apply_invite_board_access(
def normalize_invited_email(email: str) -> str:
"""Normalize an invited email address for storage/comparison."""
return email.strip().lower()
def normalize_role(role: str) -> str:
"""Normalize a role string and default empty values to `member`."""
return role.strip().lower() or "member"
@@ -433,6 +453,7 @@ async def apply_invite_to_member(
member: OrganizationMember,
invite: OrganizationInvite,
) -> None:
"""Apply invite role/access grants onto an existing organization member."""
now = utcnow()
member_changed = False
invite_role = normalize_role(invite.role or "member")

View File

@@ -1,3 +1,5 @@
"""Task-dependency helpers for validation, querying, and replacement."""
from __future__ import annotations
from collections import defaultdict
@@ -14,6 +16,7 @@ from app.models.task_dependencies import TaskDependency
from app.models.tasks import Task
DONE_STATUS: Final[str] = "done"
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession, Mapping, Sequence)
def _dedupe_uuid_list(values: Sequence[UUID]) -> list[UUID]:
@@ -34,6 +37,7 @@ async def dependency_ids_by_task_id(
board_id: UUID,
task_ids: Sequence[UUID],
) -> dict[UUID, list[UUID]]:
"""Return dependency ids keyed by task id for tasks on a board."""
if not task_ids:
return {}
rows = list(
@@ -41,8 +45,8 @@ async def dependency_ids_by_task_id(
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id))
.where(col(TaskDependency.board_id) == board_id)
.where(col(TaskDependency.task_id).in_(task_ids))
.order_by(col(TaskDependency.created_at).asc())
)
.order_by(col(TaskDependency.created_at).asc()),
),
)
mapping: dict[UUID, list[UUID]] = defaultdict(list)
for task_id, depends_on_task_id in rows:
@@ -56,16 +60,17 @@ async def dependency_status_by_id(
board_id: UUID,
dependency_ids: Sequence[UUID],
) -> dict[UUID, str]:
"""Return dependency status values keyed by dependency task id."""
if not dependency_ids:
return {}
rows = list(
await session.exec(
select(col(Task.id), col(Task.status))
.where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(dependency_ids))
)
.where(col(Task.id).in_(dependency_ids)),
),
)
return {task_id: status_value for task_id, status_value in rows}
return dict(rows)
def blocked_by_dependency_ids(
@@ -73,11 +78,12 @@ def blocked_by_dependency_ids(
dependency_ids: Sequence[UUID],
status_by_id: Mapping[UUID, str],
) -> list[UUID]:
blocked: list[UUID] = []
for dep_id in dependency_ids:
if status_by_id.get(dep_id) != DONE_STATUS:
blocked.append(dep_id)
return blocked
"""Return dependency ids that are not yet in the done status."""
return [
dep_id
for dep_id in dependency_ids
if status_by_id.get(dep_id) != DONE_STATUS
]
async def blocked_by_for_task(
@@ -87,6 +93,7 @@ async def blocked_by_for_task(
task_id: UUID,
dependency_ids: Sequence[UUID] | None = None,
) -> list[UUID]:
"""Return unresolved dependency ids for the provided task."""
dep_ids = list(dependency_ids or [])
if dependency_ids is None:
deps_map = await dependency_ids_by_task_id(
@@ -97,11 +104,16 @@ async def blocked_by_for_task(
dep_ids = deps_map.get(task_id, [])
if not dep_ids:
return []
status_by_id = await dependency_status_by_id(session, board_id=board_id, dependency_ids=dep_ids)
status_by_id = await dependency_status_by_id(
session,
board_id=board_id,
dependency_ids=dep_ids,
)
return blocked_by_dependency_ids(dependency_ids=dep_ids, status_by_id=status_by_id)
def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
"""Detect cycles in a directed dependency graph."""
visited: set[UUID] = set()
in_stack: set[UUID] = set()
@@ -118,10 +130,7 @@ def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
in_stack.remove(node)
return False
for node in nodes:
if dfs(node):
return True
return False
return any(dfs(node) for node in nodes)
async def validate_dependency_update(
@@ -131,6 +140,7 @@ async def validate_dependency_update(
task_id: UUID,
depends_on_task_ids: Sequence[UUID],
) -> list[UUID]:
"""Validate a dependency update and return normalized dependency ids."""
normalized = _dedupe_uuid_list(depends_on_task_ids)
if task_id in normalized:
raise HTTPException(
@@ -145,8 +155,8 @@ async def validate_dependency_update(
await session.exec(
select(col(Task.id))
.where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(normalized))
)
.where(col(Task.id).in_(normalized)),
),
)
missing = [dep_id for dep_id in normalized if dep_id not in existing_ids]
if missing:
@@ -159,13 +169,18 @@ async def validate_dependency_update(
)
# Ensure the dependency graph is acyclic after applying the update.
task_ids = list(await session.exec(select(col(Task.id)).where(col(Task.board_id) == board_id)))
task_ids = list(
await session.exec(
select(col(Task.id)).where(col(Task.board_id) == board_id),
),
)
rows = list(
await session.exec(
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id)).where(
col(TaskDependency.board_id) == board_id
)
)
select(
col(TaskDependency.task_id),
col(TaskDependency.depends_on_task_id),
).where(col(TaskDependency.board_id) == board_id),
),
)
edges: dict[UUID, set[UUID]] = defaultdict(set)
for src, dst in rows:
@@ -188,6 +203,7 @@ async def replace_task_dependencies(
task_id: UUID,
depends_on_task_ids: Sequence[UUID],
) -> list[UUID]:
"""Replace dependencies for a task and return the normalized dependency ids."""
normalized = await validate_dependency_update(
session,
board_id=board_id,
@@ -207,7 +223,7 @@ async def replace_task_dependencies(
board_id=board_id,
task_id=task_id,
depends_on_task_id=dep_id,
)
),
)
return normalized
@@ -218,9 +234,10 @@ async def dependent_task_ids(
board_id: UUID,
dependency_task_id: UUID,
) -> list[UUID]:
"""Return task ids that depend on the provided dependency task id."""
rows = await session.exec(
select(col(TaskDependency.task_id))
.where(col(TaskDependency.board_id) == board_id)
.where(col(TaskDependency.depends_on_task_id) == dependency_task_id)
.where(col(TaskDependency.depends_on_task_id) == dependency_task_id),
)
return list(rows)

View File

@@ -1,9 +1,12 @@
"""Gateway template synchronization orchestration."""
from __future__ import annotations
import asyncio
import random
import re
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TypeVar
from uuid import UUID, uuid4
@@ -11,7 +14,11 @@ from sqlalchemy import func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import generate_agent_token, hash_agent_token, verify_agent_token
from app.core.agent_tokens import (
generate_agent_token,
hash_agent_token,
verify_agent_token,
)
from app.core.time import utcnow
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call
@@ -49,6 +56,31 @@ _TRANSIENT_GATEWAY_ERROR_MARKERS = (
)
T = TypeVar("T")
_SECURE_RANDOM = random.SystemRandom()
_RUNTIME_TYPE_REFERENCES = (Awaitable, Callable, AsyncSession, Gateway, User, UUID)
@dataclass(frozen=True)
class GatewayTemplateSyncOptions:
"""Runtime options controlling gateway template synchronization."""
user: User | None
include_main: bool = True
reset_sessions: bool = False
rotate_tokens: bool = False
force_bootstrap: bool = False
board_id: UUID | None = None
@dataclass(frozen=True)
class _SyncContext:
"""Shared state passed to sync helper functions."""
session: AsyncSession
gateway: Gateway
config: GatewayClientConfig
backoff: _GatewayBackoff
options: GatewayTemplateSyncOptions
def _slugify(value: str) -> str:
@@ -70,7 +102,10 @@ def _is_transient_gateway_error(exc: Exception) -> bool:
def _gateway_timeout_message(exc: OpenClawGatewayError) -> str:
return f"Gateway unreachable after 10 minutes (template sync timeout). Last error: {exc}"
return (
"Gateway unreachable after 10 minutes (template sync timeout). "
f"Last error: {exc}"
)
class _GatewayBackoff:
@@ -91,16 +126,25 @@ class _GatewayBackoff:
def reset(self) -> None:
self._delay_s = self._base_delay_s
async def _attempt(
self,
fn: Callable[[], Awaitable[T]],
) -> tuple[T | None, OpenClawGatewayError | None]:
try:
return await fn(), None
except OpenClawGatewayError as exc:
return None, exc
async def run(self, fn: Callable[[], Awaitable[T]]) -> T:
# Use per-call deadlines so long-running syncs can still tolerate a later
# gateway restart without having an already-expired retry window.
deadline_s = asyncio.get_running_loop().time() + self._timeout_s
while True:
try:
value = await fn()
except OpenClawGatewayError as exc:
value, error = await self._attempt(fn)
if error is not None:
exc = error
if not _is_transient_gateway_error(exc):
raise
raise exc
now = asyncio.get_running_loop().time()
remaining = deadline_s - now
if remaining <= 0:
@@ -108,13 +152,16 @@ class _GatewayBackoff:
sleep_s = min(self._delay_s, remaining)
if self._jitter:
sleep_s *= 1.0 + random.uniform(-self._jitter, self._jitter)
sleep_s *= 1.0 + _SECURE_RANDOM.uniform(
-self._jitter,
self._jitter,
)
sleep_s = max(0.0, min(sleep_s, remaining))
await asyncio.sleep(sleep_s)
self._delay_s = min(self._delay_s * 2.0, self._max_delay_s)
else:
self.reset()
return value
continue
self.reset()
return value
async def _with_gateway_retry(
@@ -138,23 +185,25 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
return agent_id or None
def _extract_agent_id(payload: object) -> str | None:
def _from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
for item in items:
if isinstance(item, str) and item.strip():
return item.strip()
if not isinstance(item, dict):
continue
for key in ("id", "agentId", "agent_id"):
raw = item.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
def _extract_agent_id_from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
for item in items:
if isinstance(item, str) and item.strip():
return item.strip()
if not isinstance(item, dict):
continue
for key in ("id", "agentId", "agent_id"):
raw = item.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
return None
def _extract_agent_id(payload: object) -> str | None:
"""Extract a default gateway agent id from common list payload shapes."""
if isinstance(payload, list):
return _from_list(payload)
return _extract_agent_id_from_list(payload)
if not isinstance(payload, dict):
return None
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
@@ -162,7 +211,7 @@ def _extract_agent_id(payload: object) -> str | None:
if isinstance(raw, str) and raw.strip():
return raw.strip()
for key in ("agents", "items", "list", "data"):
agent_id = _from_list(payload.get(key))
agent_id = _extract_agent_id_from_list(payload.get(key))
if agent_id:
return agent_id
return None
@@ -212,9 +261,6 @@ async def _get_agent_file(
if isinstance(payload, str):
return payload
if isinstance(payload, dict):
# Common shapes:
# - {"name": "...", "content": "..."}
# - {"file": {"name": "...", "content": "..." }}
content = payload.get("content")
if isinstance(content, str):
return content
@@ -291,18 +337,53 @@ async def _paused_board_ids(session: AsyncSession, board_ids: list[UUID]) -> set
return paused
async def sync_gateway_templates(
session: AsyncSession,
def _append_sync_error(
result: GatewayTemplatesSyncResult,
*,
message: str,
agent: Agent | None = None,
board: Board | None = None,
) -> None:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id if agent else None,
agent_name=agent.name if agent else None,
board_id=board.id if board else None,
message=message,
),
)
async def _rotate_agent_token(session: AsyncSession, agent: Agent) -> str:
token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(token)
agent.updated_at = utcnow()
session.add(agent)
await session.commit()
await session.refresh(agent)
return token
async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool:
try:
async def _do_ping() -> object:
return await openclaw_call("agents.list", config=ctx.config)
await ctx.backoff.run(_do_ping)
except (TimeoutError, OpenClawGatewayError) as exc:
_append_sync_error(result, message=str(exc))
return False
else:
return True
def _base_result(
gateway: Gateway,
*,
user: User | None,
include_main: bool = True,
reset_sessions: bool = False,
rotate_tokens: bool = False,
force_bootstrap: bool = False,
board_id: UUID | None = None,
include_main: bool,
reset_sessions: bool,
) -> GatewayTemplatesSyncResult:
result = GatewayTemplatesSyncResult(
return GatewayTemplatesSyncResult(
gateway_id=gateway.id,
include_main=include_main,
reset_sessions=reset_sessions,
@@ -310,45 +391,239 @@ async def sync_gateway_templates(
agents_skipped=0,
main_updated=False,
)
def _boards_by_id(
boards: list[Board],
*,
board_id: UUID | None,
) -> dict[UUID, Board] | None:
boards_by_id = {board.id: board for board in boards}
if board_id is None:
return boards_by_id
board = boards_by_id.get(board_id)
if board is None:
return None
return {board_id: board}
async def _resolve_agent_auth_token(
ctx: _SyncContext,
result: GatewayTemplatesSyncResult,
agent: Agent,
board: Board | None,
*,
agent_gateway_id: str,
) -> tuple[str | None, bool]:
try:
auth_token = await _get_existing_auth_token(
agent_gateway_id=agent_gateway_id,
config=ctx.config,
backoff=ctx.backoff,
)
except TimeoutError as exc:
_append_sync_error(result, agent=agent, board=board, message=str(exc))
return None, True
if not auth_token:
if not ctx.options.rotate_tokens:
result.agents_skipped += 1
_append_sync_error(
result,
agent=agent,
board=board,
message=(
"Skipping agent: unable to read AUTH_TOKEN from TOOLS.md "
"(run with rotate_tokens=true to re-key)."
),
)
return None, False
auth_token = await _rotate_agent_token(ctx.session, agent)
if agent.agent_token_hash and not verify_agent_token(
auth_token,
agent.agent_token_hash,
):
if ctx.options.rotate_tokens:
auth_token = await _rotate_agent_token(ctx.session, agent)
else:
_append_sync_error(
result,
agent=agent,
board=board,
message=(
"Warning: AUTH_TOKEN in TOOLS.md does not match backend "
"token hash (agent auth may be broken)."
),
)
return auth_token, False
async def _sync_one_agent(
ctx: _SyncContext,
result: GatewayTemplatesSyncResult,
agent: Agent,
board: Board,
) -> bool:
auth_token, fatal = await _resolve_agent_auth_token(
ctx,
result,
agent,
board,
agent_gateway_id=_gateway_agent_id(agent),
)
if fatal:
return True
if not auth_token:
return False
try:
async def _do_provision() -> None:
await provision_agent(
agent,
board,
ctx.gateway,
auth_token,
ctx.options.user,
action="update",
force_bootstrap=ctx.options.force_bootstrap,
reset_session=ctx.options.reset_sessions,
)
await _with_gateway_retry(_do_provision, backoff=ctx.backoff)
result.agents_updated += 1
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
result.agents_skipped += 1
_append_sync_error(result, agent=agent, board=board, message=str(exc))
return True
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
result.agents_skipped += 1
_append_sync_error(
result,
agent=agent,
board=board,
message=f"Failed to sync templates: {exc}",
)
return False
else:
return False
async def _sync_main_agent(
ctx: _SyncContext,
result: GatewayTemplatesSyncResult,
) -> bool:
main_agent = (
await Agent.objects.all()
.filter(col(Agent.openclaw_session_id) == ctx.gateway.main_session_key)
.first(ctx.session)
)
if main_agent is None:
_append_sync_error(
result,
message=(
"Gateway main agent record not found; "
"skipping main agent template sync."
),
)
return True
try:
main_gateway_agent_id = await _gateway_default_agent_id(
ctx.config,
fallback_session_key=ctx.gateway.main_session_key,
backoff=ctx.backoff,
)
except TimeoutError as exc:
_append_sync_error(result, agent=main_agent, message=str(exc))
return True
if not main_gateway_agent_id:
_append_sync_error(
result,
agent=main_agent,
message="Unable to resolve gateway default agent id for main agent.",
)
return True
token, fatal = await _resolve_agent_auth_token(
ctx,
result,
main_agent,
board=None,
agent_gateway_id=main_gateway_agent_id,
)
if fatal:
return True
if not token:
_append_sync_error(
result,
agent=main_agent,
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
)
return True
stop_sync = False
try:
async def _do_provision_main() -> None:
await provision_main_agent(
main_agent,
ctx.gateway,
token,
ctx.options.user,
action="update",
force_bootstrap=ctx.options.force_bootstrap,
reset_session=ctx.options.reset_sessions,
)
await _with_gateway_retry(_do_provision_main, backoff=ctx.backoff)
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
_append_sync_error(result, agent=main_agent, message=str(exc))
stop_sync = True
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
_append_sync_error(
result,
agent=main_agent,
message=f"Failed to sync main agent templates: {exc}",
)
else:
result.main_updated = True
return stop_sync
async def sync_gateway_templates(
session: AsyncSession,
gateway: Gateway,
options: GatewayTemplateSyncOptions,
) -> GatewayTemplatesSyncResult:
"""Synchronize AGENTS/TOOLS/etc templates to gateway-connected agents."""
result = _base_result(
gateway,
include_main=options.include_main,
reset_sessions=options.reset_sessions,
)
if not gateway.url:
result.errors.append(
GatewayTemplatesSyncError(message="Gateway URL is not configured for this gateway.")
_append_sync_error(
result,
message="Gateway URL is not configured for this gateway.",
)
return result
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
backoff = _GatewayBackoff(timeout_s=10 * 60)
# First, wait for the gateway to be reachable (e.g. while it is restarting).
try:
async def _do_ping() -> object:
return await openclaw_call("agents.list", config=client_config)
await backoff.run(_do_ping)
except TimeoutError as exc:
result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
return result
except OpenClawGatewayError as exc:
result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
ctx = _SyncContext(
session=session,
gateway=gateway,
config=GatewayClientConfig(url=gateway.url, token=gateway.token),
backoff=_GatewayBackoff(timeout_s=10 * 60),
options=options,
)
if not await _ping_gateway(ctx, result):
return result
boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session)
boards_by_id = {board.id: board for board in boards}
if board_id is not None:
board = boards_by_id.get(board_id)
if board is None:
result.errors.append(
GatewayTemplatesSyncError(
board_id=board_id,
message="Board does not belong to this gateway.",
)
)
return result
boards_by_id = {board_id: board}
boards_by_id = _boards_by_id(boards, board_id=options.board_id)
if boards_by_id is None:
_append_sync_error(
result,
message="Board does not belong to this gateway.",
)
return result
paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys()))
if boards_by_id:
agents = await (
Agent.objects.by_field_in("board_id", list(boards_by_id.keys()))
@@ -358,251 +633,24 @@ async def sync_gateway_templates(
else:
agents = []
stop_sync = False
for agent in agents:
board = boards_by_id.get(agent.board_id) if agent.board_id is not None else None
if board is None:
result.agents_skipped += 1
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=agent.board_id,
message="Skipping agent: board not found for agent.",
)
_append_sync_error(
result,
agent=agent,
message="Skipping agent: board not found for agent.",
)
continue
if board.id in paused_board_ids:
result.agents_skipped += 1
continue
stop_sync = await _sync_one_agent(ctx, result, agent, board)
if stop_sync:
break
agent_gateway_id = _gateway_agent_id(agent)
try:
auth_token = await _get_existing_auth_token(
agent_gateway_id=agent_gateway_id,
config=client_config,
backoff=backoff,
)
except TimeoutError as exc:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message=str(exc),
)
)
return result
if not auth_token:
if not rotate_tokens:
result.agents_skipped += 1
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message="Skipping agent: unable to read AUTH_TOKEN from TOOLS.md (run with rotate_tokens=true to re-key).",
)
)
continue
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.updated_at = utcnow()
session.add(agent)
await session.commit()
await session.refresh(agent)
auth_token = raw_token
if agent.agent_token_hash and not verify_agent_token(auth_token, agent.agent_token_hash):
# Do not block template sync on token drift; optionally re-key.
if rotate_tokens:
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.updated_at = utcnow()
session.add(agent)
await session.commit()
await session.refresh(agent)
auth_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (agent auth may be broken).",
)
)
try:
agent_item: Agent = agent
board_item: Board = board
auth_token_value: str = auth_token
async def _do_provision(
agent_item: Agent = agent_item,
board_item: Board = board_item,
auth_token_value: str = auth_token_value,
) -> None:
await provision_agent(
agent_item,
board_item,
gateway,
auth_token_value,
user,
action="update",
force_bootstrap=force_bootstrap,
reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision, backoff=backoff)
result.agents_updated += 1
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
result.agents_skipped += 1
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message=str(exc),
)
)
return result
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
result.agents_skipped += 1
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message=f"Failed to sync templates: {exc}",
)
)
if include_main:
main_agent = (
await Agent.objects.all()
.filter(col(Agent.openclaw_session_id) == gateway.main_session_key)
.first(session)
)
if main_agent is None:
result.errors.append(
GatewayTemplatesSyncError(
message="Gateway main agent record not found; skipping main agent template sync.",
)
)
return result
try:
main_gateway_agent_id = await _gateway_default_agent_id(
client_config,
fallback_session_key=gateway.main_session_key,
backoff=backoff,
)
except TimeoutError as exc:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=str(exc),
)
)
return result
if not main_gateway_agent_id:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Unable to resolve gateway default agent id for main agent.",
)
)
return result
try:
main_token = await _get_existing_auth_token(
agent_gateway_id=main_gateway_agent_id,
config=client_config,
backoff=backoff,
)
except TimeoutError as exc:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=str(exc),
)
)
return result
if not main_token:
if rotate_tokens:
raw_token = generate_agent_token()
main_agent.agent_token_hash = hash_agent_token(raw_token)
main_agent.updated_at = utcnow()
session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
)
)
return result
if main_agent.agent_token_hash and not verify_agent_token(
main_token, main_agent.agent_token_hash
):
if rotate_tokens:
raw_token = generate_agent_token()
main_agent.agent_token_hash = hash_agent_token(raw_token)
main_agent.updated_at = utcnow()
session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).",
)
)
try:
async def _do_provision_main() -> None:
await provision_main_agent(
main_agent,
gateway,
main_token,
user,
action="update",
force_bootstrap=force_bootstrap,
reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision_main, backoff=backoff)
result.main_updated = True
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=str(exc),
)
)
return result
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=f"Failed to sync main agent templates: {exc}",
)
)
if not stop_sync and options.include_main:
await _sync_main_agent(ctx, result)
return result