feat: add disable_device_pairing option to gateway configuration

This commit is contained in:
Abhimanyu Saharan
2026-02-22 19:19:26 +05:30
parent e39b2069fb
commit 3dfb70cd90
34 changed files with 1229 additions and 178 deletions

View File

@@ -10,8 +10,8 @@ from __future__ import annotations
import asyncio
import json
from dataclasses import dataclass
from time import perf_counter
from typing import Any
from time import perf_counter, time
from typing import Any, Literal
from urllib.parse import urlencode, urlparse, urlunparse
from uuid import uuid4
@@ -19,6 +19,12 @@ import websockets
from websockets.exceptions import WebSocketException
from app.core.logging import TRACE_LEVEL, get_logger
from app.services.openclaw.device_identity import (
build_device_auth_payload,
load_or_create_device_identity,
public_key_raw_base64url_from_pem,
sign_device_payload,
)
PROTOCOL_VERSION = 3
logger = get_logger(__name__)
@@ -28,6 +34,11 @@ GATEWAY_OPERATOR_SCOPES = (
"operator.approvals",
"operator.pairing",
)
DEFAULT_GATEWAY_CLIENT_ID = "gateway-client"
DEFAULT_GATEWAY_CLIENT_MODE = "backend"
CONTROL_UI_CLIENT_ID = "openclaw-control-ui"
CONTROL_UI_CLIENT_MODE = "ui"
GatewayConnectMode = Literal["device", "control_ui"]
# NOTE: These are the base gateway methods from the OpenClaw gateway repo.
# The gateway can expose additional methods at runtime via channel plugins.
@@ -160,6 +171,7 @@ class GatewayConfig:
url: str
token: str | None = None
disable_device_pairing: bool = False
def _build_gateway_url(config: GatewayConfig) -> str:
@@ -180,6 +192,60 @@ def _redacted_url_for_log(raw_url: str) -> str:
return str(urlunparse(parsed._replace(query="", fragment="")))
def _build_control_ui_origin(gateway_url: str) -> str | None:
parsed = urlparse(gateway_url)
if not parsed.hostname:
return None
if parsed.scheme in {"ws", "http"}:
origin_scheme = "http"
elif parsed.scheme in {"wss", "https"}:
origin_scheme = "https"
else:
return None
host = parsed.hostname
if ":" in host and not host.startswith("["):
host = f"[{host}]"
if parsed.port is not None:
host = f"{host}:{parsed.port}"
return f"{origin_scheme}://{host}"
def _resolve_connect_mode(config: GatewayConfig) -> GatewayConnectMode:
return "control_ui" if config.disable_device_pairing else "device"
def _build_device_connect_payload(
*,
client_id: str,
client_mode: str,
role: str,
scopes: list[str],
auth_token: str | None,
connect_nonce: str | None,
) -> dict[str, Any]:
identity = load_or_create_device_identity()
signed_at_ms = int(time() * 1000)
payload = build_device_auth_payload(
device_id=identity.device_id,
client_id=client_id,
client_mode=client_mode,
role=role,
scopes=scopes,
signed_at_ms=signed_at_ms,
token=auth_token,
nonce=connect_nonce,
)
device_payload: dict[str, Any] = {
"id": identity.device_id,
"publicKey": public_key_raw_base64url_from_pem(identity.public_key_pem),
"signature": sign_device_payload(identity.private_key_pem, payload),
"signedAt": signed_at_ms,
}
if connect_nonce:
device_payload["nonce"] = connect_nonce
return device_payload
async def _await_response(
ws: websockets.ClientConnection,
request_id: str,
@@ -231,19 +297,36 @@ async def _send_request(
return await _await_response(ws, request_id)
def _build_connect_params(config: GatewayConfig) -> dict[str, Any]:
def _build_connect_params(
config: GatewayConfig,
*,
connect_nonce: str | None = None,
) -> dict[str, Any]:
role = "operator"
scopes = list(GATEWAY_OPERATOR_SCOPES)
connect_mode = _resolve_connect_mode(config)
use_control_ui = connect_mode == "control_ui"
params: dict[str, Any] = {
"minProtocol": PROTOCOL_VERSION,
"maxProtocol": PROTOCOL_VERSION,
"role": "operator",
"scopes": list(GATEWAY_OPERATOR_SCOPES),
"role": role,
"scopes": scopes,
"client": {
"id": "gateway-client",
"id": CONTROL_UI_CLIENT_ID if use_control_ui else DEFAULT_GATEWAY_CLIENT_ID,
"version": "1.0.0",
"platform": "web",
"mode": "ui",
"platform": "python",
"mode": CONTROL_UI_CLIENT_MODE if use_control_ui else DEFAULT_GATEWAY_CLIENT_MODE,
},
}
if not use_control_ui:
params["device"] = _build_device_connect_payload(
client_id=DEFAULT_GATEWAY_CLIENT_ID,
client_mode=DEFAULT_GATEWAY_CLIENT_MODE,
role=role,
scopes=scopes,
auth_token=config.token,
connect_nonce=connect_nonce,
)
if config.token:
params["auth"] = {"token": config.token}
return params
@@ -254,11 +337,18 @@ async def _ensure_connected(
first_message: str | bytes | None,
config: GatewayConfig,
) -> object:
connect_nonce: str | None = None
if first_message:
if isinstance(first_message, bytes):
first_message = first_message.decode("utf-8")
data = json.loads(first_message)
if data.get("type") != "event" or data.get("event") != "connect.challenge":
if data.get("type") == "event" and data.get("event") == "connect.challenge":
payload = data.get("payload")
if isinstance(payload, dict):
nonce = payload.get("nonce")
if isinstance(nonce, str) and nonce.strip():
connect_nonce = nonce.strip()
else:
logger.warning(
"gateway.rpc.connect.unexpected_first_message type=%s event=%s",
data.get("type"),
@@ -269,12 +359,52 @@ async def _ensure_connected(
"type": "req",
"id": connect_id,
"method": "connect",
"params": _build_connect_params(config),
"params": _build_connect_params(config, connect_nonce=connect_nonce),
}
await ws.send(json.dumps(response))
return await _await_response(ws, connect_id)
async def _recv_first_message_or_none(
ws: websockets.ClientConnection,
) -> str | bytes | None:
try:
return await asyncio.wait_for(ws.recv(), timeout=2)
except TimeoutError:
return None
async def _openclaw_call_once(
method: str,
params: dict[str, Any] | None,
*,
config: GatewayConfig,
gateway_url: str,
) -> object:
origin = _build_control_ui_origin(gateway_url) if config.disable_device_pairing else None
connect_kwargs: dict[str, Any] = {"ping_interval": None}
if origin is not None:
connect_kwargs["origin"] = origin
async with websockets.connect(gateway_url, **connect_kwargs) as ws:
first_message = await _recv_first_message_or_none(ws)
await _ensure_connected(ws, first_message, config)
return await _send_request(ws, method, params)
async def _openclaw_connect_metadata_once(
*,
config: GatewayConfig,
gateway_url: str,
) -> object:
origin = _build_control_ui_origin(gateway_url) if config.disable_device_pairing else None
connect_kwargs: dict[str, Any] = {"ping_interval": None}
if origin is not None:
connect_kwargs["origin"] = origin
async with websockets.connect(gateway_url, **connect_kwargs) as ws:
first_message = await _recv_first_message_or_none(ws)
return await _ensure_connected(ws, first_message, config)
async def openclaw_call(
method: str,
params: dict[str, Any] | None = None,
@@ -290,20 +420,18 @@ async def openclaw_call(
_redacted_url_for_log(gateway_url),
)
try:
async with websockets.connect(gateway_url, ping_interval=None) as ws:
first_message = None
try:
first_message = await asyncio.wait_for(ws.recv(), timeout=2)
except TimeoutError:
first_message = None
await _ensure_connected(ws, first_message, config)
payload = await _send_request(ws, method, params)
logger.debug(
"gateway.rpc.call.success method=%s duration_ms=%s",
method,
int((perf_counter() - started_at) * 1000),
)
return payload
payload = await _openclaw_call_once(
method,
params,
config=config,
gateway_url=gateway_url,
)
logger.debug(
"gateway.rpc.call.success method=%s duration_ms=%s",
method,
int((perf_counter() - started_at) * 1000),
)
return payload
except OpenClawGatewayError:
logger.warning(
"gateway.rpc.call.gateway_error method=%s duration_ms=%s",
@@ -336,18 +464,15 @@ async def openclaw_connect_metadata(*, config: GatewayConfig) -> object:
_redacted_url_for_log(gateway_url),
)
try:
async with websockets.connect(gateway_url, ping_interval=None) as ws:
first_message = None
try:
first_message = await asyncio.wait_for(ws.recv(), timeout=2)
except TimeoutError:
first_message = None
metadata = await _ensure_connected(ws, first_message, config)
logger.debug(
"gateway.rpc.connect_metadata.success duration_ms=%s",
int((perf_counter() - started_at) * 1000),
)
return metadata
metadata = await _openclaw_connect_metadata_once(
config=config,
gateway_url=gateway_url,
)
logger.debug(
"gateway.rpc.connect_metadata.success duration_ms=%s",
int((perf_counter() - started_at) * 1000),
)
return metadata
except OpenClawGatewayError:
logger.warning(
"gateway.rpc.connect_metadata.gateway_error duration_ms=%s",