refactor: replace direct gateway message sending with safe dispatch helper

This commit is contained in:
Abhimanyu Saharan
2026-02-10 15:18:39 +05:30
parent 42b061f72d
commit e75b2844bb
5 changed files with 87 additions and 70 deletions

View File

@@ -34,9 +34,8 @@ from app.schemas.board_group_memory import BoardGroupMemoryCreate, BoardGroupMem
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.openclaw.shared import ( from app.services.openclaw.shared import (
GatewayTransportError,
optional_gateway_config_for_board, optional_gateway_config_for_board,
send_gateway_agent_message, send_gateway_agent_message_safe,
) )
from app.services.organizations import ( from app.services.organizations import (
is_org_admin, is_org_admin,
@@ -243,14 +242,13 @@ async def _notify_group_target(
f"POST {context.base_url}/api/v1/boards/{board.id}/group-memory\n" f"POST {context.base_url}/api/v1/boards/{board.id}/group-memory\n"
'Body: {"content":"...","tags":["chat"]}' 'Body: {"content":"...","tags":["chat"]}'
) )
try: error = await send_gateway_agent_message_safe(
await send_gateway_agent_message(
session_key=session_key, session_key=session_key,
config=config, config=config,
agent_name=agent.name, agent_name=agent.name,
message=message, message=message,
) )
except GatewayTransportError: if error is not None:
return return

View File

@@ -30,9 +30,8 @@ from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.openclaw.shared import ( from app.services.openclaw.shared import (
GatewayClientConfig, GatewayClientConfig,
GatewayTransportError,
optional_gateway_config_for_board, optional_gateway_config_for_board,
send_gateway_agent_message, send_gateway_agent_message_safe,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -116,15 +115,14 @@ async def _send_control_command(
continue continue
if not agent.openclaw_session_id: if not agent.openclaw_session_id:
continue continue
try: error = await send_gateway_agent_message_safe(
await send_gateway_agent_message(
session_key=agent.openclaw_session_id, session_key=agent.openclaw_session_id,
config=config, config=config,
agent_name=agent.name, agent_name=agent.name,
message=command, message=command,
deliver=True, deliver=True,
) )
except GatewayTransportError: if error is not None:
continue continue
@@ -208,14 +206,13 @@ async def _notify_chat_targets(
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n" f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
'Body: {"content":"...","tags":["chat"]}' 'Body: {"content":"...","tags":["chat"]}'
) )
try: error = await send_gateway_agent_message_safe(
await send_gateway_agent_message(
session_key=agent.openclaw_session_id, session_key=agent.openclaw_session_id,
config=config, config=config,
agent_name=agent.name, agent_name=agent.name,
message=message, message=message,
) )
except GatewayTransportError: if error is not None:
continue continue

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from collections import deque from collections import deque
from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -45,7 +44,7 @@ from app.services.openclaw.shared import (
GatewayClientConfig, GatewayClientConfig,
GatewayTransportError, GatewayTransportError,
optional_gateway_config_for_board, optional_gateway_config_for_board,
send_gateway_agent_message, send_gateway_agent_message_safe,
) )
from app.services.organizations import require_board_access from app.services.organizations import require_board_access
from app.services.task_dependencies import ( from app.services.task_dependencies import (
@@ -309,8 +308,8 @@ async def _send_lead_task_message(
session_key: str, session_key: str,
config: GatewayClientConfig, config: GatewayClientConfig,
message: str, message: str,
) -> None: ) -> GatewayTransportError | None:
await send_gateway_agent_message( return await send_gateway_agent_message_safe(
session_key=session_key, session_key=session_key,
config=config, config=config,
agent_name="Lead Agent", agent_name="Lead Agent",
@@ -325,8 +324,8 @@ async def _send_agent_task_message(
config: GatewayClientConfig, config: GatewayClientConfig,
agent_name: str, agent_name: str,
message: str, message: str,
) -> None: ) -> GatewayTransportError | None:
await send_gateway_agent_message( return await send_gateway_agent_message_safe(
session_key=session_key, session_key=session_key,
config=config, config=config,
agent_name=agent_name, agent_name=agent_name,
@@ -361,13 +360,13 @@ async def _notify_agent_on_task_assign(
+ "\n".join(details) + "\n".join(details)
+ ("\n\nTake action: open the task and begin work. " "Post updates as task comments.") + ("\n\nTake action: open the task and begin work. " "Post updates as task comments.")
) )
try: error = await _send_agent_task_message(
await _send_agent_task_message(
session_key=agent.openclaw_session_id, session_key=agent.openclaw_session_id,
config=config, config=config,
agent_name=agent.name, agent_name=agent.name,
message=message, message=message,
) )
if error is None:
record_activity( record_activity(
session, session,
event_type="task.assignee_notified", event_type="task.assignee_notified",
@@ -376,11 +375,11 @@ async def _notify_agent_on_task_assign(
task_id=task.id, task_id=task.id,
) )
await session.commit() await session.commit()
except GatewayTransportError as exc: else:
record_activity( record_activity(
session, session,
event_type="task.assignee_notify_failed", event_type="task.assignee_notify_failed",
message=f"Assignee notify failed: {exc}", message=f"Assignee notify failed: {error}",
agent_id=agent.id, agent_id=agent.id,
task_id=task.id, task_id=task.id,
) )
@@ -433,12 +432,12 @@ async def _notify_lead_on_task_create(
+ "\n".join(details) + "\n".join(details)
+ "\n\nTake action: triage, assign, or plan next steps." + "\n\nTake action: triage, assign, or plan next steps."
) )
try: error = await _send_lead_task_message(
await _send_lead_task_message(
session_key=lead.openclaw_session_id, session_key=lead.openclaw_session_id,
config=config, config=config,
message=message, message=message,
) )
if error is None:
record_activity( record_activity(
session, session,
event_type="task.lead_notified", event_type="task.lead_notified",
@@ -447,11 +446,11 @@ async def _notify_lead_on_task_create(
task_id=task.id, task_id=task.id,
) )
await session.commit() await session.commit()
except GatewayTransportError as exc: else:
record_activity( record_activity(
session, session,
event_type="task.lead_notify_failed", event_type="task.lead_notify_failed",
message=f"Lead notify failed: {exc}", message=f"Lead notify failed: {error}",
agent_id=lead.id, agent_id=lead.id,
task_id=task.id, task_id=task.id,
) )
@@ -488,12 +487,12 @@ async def _notify_lead_on_task_unassigned(
+ "\n".join(details) + "\n".join(details)
+ "\n\nTake action: assign a new owner or adjust the plan." + "\n\nTake action: assign a new owner or adjust the plan."
) )
try: error = await _send_lead_task_message(
await _send_lead_task_message(
session_key=lead.openclaw_session_id, session_key=lead.openclaw_session_id,
config=config, config=config,
message=message, message=message,
) )
if error is None:
record_activity( record_activity(
session, session,
event_type="task.lead_unassigned_notified", event_type="task.lead_unassigned_notified",
@@ -502,11 +501,11 @@ async def _notify_lead_on_task_unassigned(
task_id=task.id, task_id=task.id,
) )
await session.commit() await session.commit()
except GatewayTransportError as exc: else:
record_activity( record_activity(
session, session,
event_type="task.lead_unassigned_notify_failed", event_type="task.lead_unassigned_notify_failed",
message=f"Lead notify failed: {exc}", message=f"Lead notify failed: {error}",
agent_id=lead.id, agent_id=lead.id,
task_id=task.id, task_id=task.id,
) )
@@ -1057,7 +1056,6 @@ async def _notify_task_comment_targets(
"If you are mentioned but not assigned, reply in the task " "If you are mentioned but not assigned, reply in the task "
"thread but do not change task status." "thread but do not change task status."
) )
with suppress(GatewayTransportError):
await _send_agent_task_message( await _send_agent_task_message(
session_key=agent.openclaw_session_id, session_key=agent.openclaw_session_id,
config=config, config=config,

View File

@@ -23,6 +23,8 @@ if TYPE_CHECKING:
GatewayClientConfig = _GatewayClientConfig GatewayClientConfig = _GatewayClientConfig
# Keep integration exceptions behind the OpenClaw service boundary.
GatewayTransportError = OpenClawGatewayError
class GatewayAgentIdentity: class GatewayAgentIdentity:
@@ -121,6 +123,3 @@ def resolve_trace_id(correlation_id: str | None, *, prefix: str) -> str:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Keep integration exceptions behind the OpenClaw service boundary.
GatewayTransportError = OpenClawGatewayError

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import re
from pathlib import Path from pathlib import Path
@@ -26,3 +27,27 @@ def test_api_does_not_import_openclaw_gateway_client_directly() -> None:
"`app.services.openclaw.shared`) instead of directly from `app.api`. " "`app.services.openclaw.shared`) instead of directly from `app.api`. "
f"Violations: {', '.join(violations)}" f"Violations: {', '.join(violations)}"
) )
def test_api_uses_safe_gateway_dispatch_helper() -> None:
"""API modules should use `send_gateway_agent_message_safe`, not direct send."""
repo_root = Path(__file__).resolve().parents[2]
api_root = repo_root / "backend" / "app" / "api"
direct_send_pattern = re.compile(r"\bsend_gateway_agent_message\b")
violations: list[str] = []
for path in api_root.rglob("*.py"):
rel = path.relative_to(repo_root)
for lineno, raw_line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
line = raw_line.strip()
if not direct_send_pattern.search(line):
continue
if "send_gateway_agent_message_safe" in line:
continue
violations.append(f"{rel}:{lineno}")
assert not violations, (
"Use `send_gateway_agent_message_safe` from `app.services.openclaw.shared` "
"for API-level gateway notification dispatch. "
f"Violations: {', '.join(violations)}"
)