security: fix fail-open auth, streaming payload limit, and rate limiter memory leak
- agent.py: Fail closed when gateway lookup returns None instead of silently dropping the organization filter (cross-tenant board leak) - board_webhooks.py: Read request body via streaming chunks so an oversized payload is rejected before it is fully loaded into memory - rate_limit.py: Add periodic sweep of expired keys to prevent unbounded memory growth from inactive clients - test_rate_limit.py: Add test for the new sweep behavior Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
Abhimanyu Saharan
parent
858575cf6c
commit
4960d8561b
@@ -373,10 +373,14 @@ async def list_boards(
|
|||||||
# Main agents (board_id=None) must be scoped to their organization
|
# Main agents (board_id=None) must be scoped to their organization
|
||||||
# via their gateway to prevent cross-tenant board leakage.
|
# via their gateway to prevent cross-tenant board leakage.
|
||||||
gateway = await Gateway.objects.by_id(agent_ctx.agent.gateway_id).first(session)
|
gateway = await Gateway.objects.by_id(agent_ctx.agent.gateway_id).first(session)
|
||||||
if gateway is not None:
|
if gateway is None:
|
||||||
statement = statement.where(
|
raise HTTPException(
|
||||||
col(Board.organization_id) == gateway.organization_id,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Agent gateway not found; cannot determine organization scope.",
|
||||||
)
|
)
|
||||||
|
statement = statement.where(
|
||||||
|
col(Board.organization_id) == gateway.organization_id,
|
||||||
|
)
|
||||||
statement = statement.order_by(col(Board.created_at).desc())
|
statement = statement.order_by(col(Board.created_at).desc())
|
||||||
return await paginate(session, statement)
|
return await paginate(session, statement)
|
||||||
|
|
||||||
|
|||||||
@@ -501,6 +501,8 @@ async def ingest_board_webhook(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Enforce a 1 MB payload size limit to prevent memory exhaustion.
|
# Enforce a 1 MB payload size limit to prevent memory exhaustion.
|
||||||
|
# Read the body in chunks via request.stream() so an attacker cannot
|
||||||
|
# cause OOM by sending a huge body with a missing/spoofed Content-Length.
|
||||||
max_payload_bytes = 1_048_576
|
max_payload_bytes = 1_048_576
|
||||||
content_length = request.headers.get("content-length")
|
content_length = request.headers.get("content-length")
|
||||||
if content_length and int(content_length) > max_payload_bytes:
|
if content_length and int(content_length) > max_payload_bytes:
|
||||||
@@ -508,12 +510,17 @@ async def ingest_board_webhook(
|
|||||||
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||||
detail=f"Payload exceeds maximum size of {max_payload_bytes} bytes.",
|
detail=f"Payload exceeds maximum size of {max_payload_bytes} bytes.",
|
||||||
)
|
)
|
||||||
raw_body = await request.body()
|
chunks: list[bytes] = []
|
||||||
if len(raw_body) > max_payload_bytes:
|
total_size = 0
|
||||||
raise HTTPException(
|
async for chunk in request.stream():
|
||||||
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
total_size += len(chunk)
|
||||||
detail=f"Payload exceeds maximum size of {max_payload_bytes} bytes.",
|
if total_size > max_payload_bytes:
|
||||||
)
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||||
|
detail=f"Payload exceeds maximum size of {max_payload_bytes} bytes.",
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
raw_body = b"".join(chunks)
|
||||||
_verify_webhook_signature(webhook, raw_body, request)
|
_verify_webhook_signature(webhook, raw_body, request)
|
||||||
|
|
||||||
content_type = request.headers.get("content-type")
|
content_type = request.headers.get("content-type")
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
|
# Run a full sweep of all keys every 128 calls to is_allowed.
|
||||||
|
_CLEANUP_INTERVAL = 128
|
||||||
|
|
||||||
|
|
||||||
class InMemoryRateLimiter:
|
class InMemoryRateLimiter:
|
||||||
"""Token-bucket rate limiter keyed by arbitrary string (typically client IP)."""
|
"""Token-bucket rate limiter keyed by arbitrary string (typically client IP)."""
|
||||||
@@ -20,14 +23,30 @@ class InMemoryRateLimiter:
|
|||||||
self._window_seconds = window_seconds
|
self._window_seconds = window_seconds
|
||||||
self._buckets: dict[str, list[float]] = defaultdict(list)
|
self._buckets: dict[str, list[float]] = defaultdict(list)
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
self._call_count = 0
|
||||||
|
|
||||||
|
def _sweep_expired(self, cutoff: float) -> None:
|
||||||
|
"""Remove keys whose timestamps have all expired."""
|
||||||
|
expired_keys = [
|
||||||
|
k for k, ts_list in self._buckets.items()
|
||||||
|
if not ts_list or ts_list[-1] <= cutoff
|
||||||
|
]
|
||||||
|
for k in expired_keys:
|
||||||
|
del self._buckets[k]
|
||||||
|
|
||||||
def is_allowed(self, key: str) -> bool:
|
def is_allowed(self, key: str) -> bool:
|
||||||
"""Return True if the request should be allowed, False if rate-limited."""
|
"""Return True if the request should be allowed, False if rate-limited."""
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
cutoff = now - self._window_seconds
|
cutoff = now - self._window_seconds
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
self._call_count += 1
|
||||||
|
# Periodically sweep all keys to evict stale entries from
|
||||||
|
# clients that have stopped making requests.
|
||||||
|
if self._call_count % _CLEANUP_INTERVAL == 0:
|
||||||
|
self._sweep_expired(cutoff)
|
||||||
|
|
||||||
timestamps = self._buckets[key]
|
timestamps = self._buckets[key]
|
||||||
# Prune expired entries
|
# Prune expired entries for the current key
|
||||||
self._buckets[key] = [ts for ts in timestamps if ts > cutoff]
|
self._buckets[key] = [ts for ts in timestamps if ts > cutoff]
|
||||||
if len(self._buckets[key]) >= self._max_requests:
|
if len(self._buckets[key]) >= self._max_requests:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -43,3 +43,29 @@ def test_window_expiry_resets_limit() -> None:
|
|||||||
future = time.monotonic() + 2.0
|
future = time.monotonic() + 2.0
|
||||||
with patch("time.monotonic", return_value=future):
|
with patch("time.monotonic", return_value=future):
|
||||||
assert limiter.is_allowed("client-a") is True
|
assert limiter.is_allowed("client-a") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_removes_expired_keys() -> None:
|
||||||
|
"""Keys whose timestamps have all expired should be evicted during periodic sweep."""
|
||||||
|
from app.core.rate_limit import _CLEANUP_INTERVAL
|
||||||
|
|
||||||
|
limiter = InMemoryRateLimiter(max_requests=100, window_seconds=1.0)
|
||||||
|
|
||||||
|
# Fill with many unique IPs
|
||||||
|
for i in range(10):
|
||||||
|
limiter.is_allowed(f"stale-{i}")
|
||||||
|
|
||||||
|
assert len(limiter._buckets) == 10
|
||||||
|
|
||||||
|
# Advance time so all timestamps expire, then trigger enough calls to
|
||||||
|
# hit the cleanup interval.
|
||||||
|
future = time.monotonic() + 2.0
|
||||||
|
with patch("time.monotonic", return_value=future):
|
||||||
|
# Drive the call count up to a multiple of _CLEANUP_INTERVAL
|
||||||
|
remaining = _CLEANUP_INTERVAL - (limiter._call_count % _CLEANUP_INTERVAL)
|
||||||
|
for i in range(remaining):
|
||||||
|
limiter.is_allowed("trigger-sweep")
|
||||||
|
|
||||||
|
# Stale keys should have been swept; only "trigger-sweep" should remain
|
||||||
|
assert "stale-0" not in limiter._buckets
|
||||||
|
assert "trigger-sweep" in limiter._buckets
|
||||||
|
|||||||
Reference in New Issue
Block a user