diff --git a/backend/app/api/board_webhooks.py b/backend/app/api/board_webhooks.py index 40c7aae7..92a96c6a 100644 --- a/backend/app/api/board_webhooks.py +++ b/backend/app/api/board_webhooks.py @@ -509,7 +509,7 @@ async def ingest_board_webhook( ) -> BoardWebhookIngestResponse: """Open inbound webhook endpoint that stores payloads and nudges the board lead.""" client_ip = get_client_ip(request) - if not webhook_ingest_limiter.is_allowed(client_ip): + if not await webhook_ingest_limiter.is_allowed(client_ip): raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS) webhook = await _require_board_webhook( session, diff --git a/backend/app/core/agent_auth.py b/backend/app/core/agent_auth.py index cb412c26..1d293f7e 100644 --- a/backend/app/core/agent_auth.py +++ b/backend/app/core/agent_auth.py @@ -115,7 +115,7 @@ async def get_agent_auth_context( ) -> AgentAuthContext: """Require and validate agent auth token from request headers.""" client_ip = get_client_ip(request) - if not agent_auth_limiter.is_allowed(client_ip): + if not await agent_auth_limiter.is_allowed(client_ip): raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS) resolved = _resolve_agent_token( agent_token, @@ -176,7 +176,7 @@ async def get_agent_auth_context_optional( # normal user Authorization headers are not throttled. if agent_token: client_ip = get_client_ip(request) - if not agent_auth_limiter.is_allowed(client_ip): + if not await agent_auth_limiter.is_allowed(client_ip): raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS) agent = await _find_agent_for_token(session, resolved) if agent is None: diff --git a/backend/app/core/rate_limit.py b/backend/app/core/rate_limit.py index 97a2079a..ad019e37 100644 --- a/backend/app/core/rate_limit.py +++ b/backend/app/core/rate_limit.py @@ -14,6 +14,7 @@ from collections import deque from threading import Lock import redis as redis_lib +import redis.asyncio as aioredis from app.core.logging import get_logger from app.core.rate_limit_backend import RateLimitBackend @@ -28,7 +29,7 @@ class RateLimiter(ABC): """Base interface for sliding-window rate limiters.""" @abstractmethod - def is_allowed(self, key: str) -> bool: + async def is_allowed(self, key: str) -> bool: """Return True if the request should be allowed, False if rate-limited.""" @@ -50,7 +51,7 @@ class InMemoryRateLimiter(RateLimiter): for k in expired_keys: del self._buckets[k] - def is_allowed(self, key: str) -> bool: + async def is_allowed(self, key: str) -> bool: """Return True if the request should be allowed, False if rate-limited.""" now = time.monotonic() cutoff = now - self._window_seconds @@ -78,9 +79,9 @@ class RedisRateLimiter(RateLimiter): """Redis-backed sliding-window rate limiter using sorted sets. Each key is stored as a Redis sorted set where members are unique - request identifiers and scores are wall-clock timestamps. A pipeline - prunes expired entries, adds the new request, counts the window, and - sets a TTL — all in a single round-trip. + request identifiers and scores are wall-clock timestamps. An async + pipeline prunes expired entries, adds the new request, counts the + window, and sets a TTL — all in a single round-trip. Fail-open: if Redis is unreachable during a request, the request is allowed and a warning is logged. @@ -97,9 +98,9 @@ class RedisRateLimiter(RateLimiter): self._namespace = namespace self._max_requests = max_requests self._window_seconds = window_seconds - self._client: redis_lib.Redis = redis_lib.Redis.from_url(redis_url) + self._client: aioredis.Redis = aioredis.from_url(redis_url) - def is_allowed(self, key: str) -> bool: + async def is_allowed(self, key: str) -> bool: """Return True if the request should be allowed, False if rate-limited.""" redis_key = f"ratelimit:{self._namespace}:{key}" now = time.time() @@ -112,7 +113,7 @@ class RedisRateLimiter(RateLimiter): pipe.zadd(redis_key, {member: now}) pipe.zcard(redis_key) pipe.expire(redis_key, int(self._window_seconds) + 1) - results = pipe.execute() + results = await pipe.execute() count: int = results[2] except Exception: logger.warning( diff --git a/backend/tests/test_rate_limit.py b/backend/tests/test_rate_limit.py index 03b0606e..534be21d 100644 --- a/backend/tests/test_rate_limit.py +++ b/backend/tests/test_rate_limit.py @@ -46,7 +46,7 @@ class _FakePipeline: self._ops.append(("expire", key, str(seconds))) return self - def execute(self) -> list[object]: + async def execute(self) -> list[object]: results: list[object] = [] for op in self._ops: cmd = op[0] @@ -86,48 +86,53 @@ class _FakeRedis: # --------------------------------------------------------------------------- -# InMemoryRateLimiter tests (unchanged from original) +# InMemoryRateLimiter tests # --------------------------------------------------------------------------- -def test_allows_requests_within_limit() -> None: +@pytest.mark.asyncio() +async def test_allows_requests_within_limit() -> None: limiter = InMemoryRateLimiter(max_requests=5, window_seconds=60.0) for _ in range(5): - assert limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True -def test_blocks_requests_over_limit() -> None: +@pytest.mark.asyncio() +async def test_blocks_requests_over_limit() -> None: limiter = InMemoryRateLimiter(max_requests=3, window_seconds=60.0) for _ in range(3): - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is False - assert limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is False -def test_separate_keys_have_independent_limits() -> None: +@pytest.mark.asyncio() +async def test_separate_keys_have_independent_limits() -> None: limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0) - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False # Different key still allowed - assert limiter.is_allowed("client-b") is True - assert limiter.is_allowed("client-b") is True - assert limiter.is_allowed("client-b") is False + assert await limiter.is_allowed("client-b") is True + assert await limiter.is_allowed("client-b") is True + assert await limiter.is_allowed("client-b") is False -def test_window_expiry_resets_limit() -> None: +@pytest.mark.asyncio() +async def test_window_expiry_resets_limit() -> None: limiter = InMemoryRateLimiter(max_requests=2, window_seconds=1.0) - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False # Simulate time passing beyond the window future = time.monotonic() + 2.0 with patch("time.monotonic", return_value=future): - assert limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True -def test_sweep_removes_expired_keys() -> None: +@pytest.mark.asyncio() +async def test_sweep_removes_expired_keys() -> None: """Keys whose timestamps have all expired should be evicted during periodic sweep.""" from app.core.rate_limit import _CLEANUP_INTERVAL @@ -135,7 +140,7 @@ def test_sweep_removes_expired_keys() -> None: # Fill with many unique IPs for i in range(10): - limiter.is_allowed(f"stale-{i}") + await limiter.is_allowed(f"stale-{i}") assert len(limiter._buckets) == 10 @@ -146,7 +151,7 @@ def test_sweep_removes_expired_keys() -> None: # Drive the call count up to a multiple of _CLEANUP_INTERVAL remaining = _CLEANUP_INTERVAL - (limiter._call_count % _CLEANUP_INTERVAL) for i in range(remaining): - limiter.is_allowed("trigger-sweep") + await limiter.is_allowed("trigger-sweep") # Stale keys should have been swept; only "trigger-sweep" should remain assert "stale-0" not in limiter._buckets @@ -166,7 +171,7 @@ def _make_redis_limiter( window_seconds: float = 60.0, ) -> RedisRateLimiter: """Build a RedisRateLimiter wired to a _FakeRedis instance.""" - with patch("redis.Redis.from_url", return_value=fake): + with patch("redis.asyncio.from_url", return_value=fake): return RedisRateLimiter( namespace=namespace, max_requests=max_requests, @@ -175,48 +180,53 @@ def _make_redis_limiter( ) -def test_redis_allows_within_limit() -> None: +@pytest.mark.asyncio() +async def test_redis_allows_within_limit() -> None: fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=5) for _ in range(5): - assert limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True -def test_redis_blocks_over_limit() -> None: +@pytest.mark.asyncio() +async def test_redis_blocks_over_limit() -> None: fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=3) for _ in range(3): - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is False - assert limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is False -def test_redis_separate_keys_independent() -> None: +@pytest.mark.asyncio() +async def test_redis_separate_keys_independent() -> None: fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=2) - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False # Different key still allowed - assert limiter.is_allowed("client-b") is True - assert limiter.is_allowed("client-b") is True - assert limiter.is_allowed("client-b") is False + assert await limiter.is_allowed("client-b") is True + assert await limiter.is_allowed("client-b") is True + assert await limiter.is_allowed("client-b") is False -def test_redis_window_expiry() -> None: +@pytest.mark.asyncio() +async def test_redis_window_expiry() -> None: fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=2, window_seconds=1.0) - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False # Simulate time passing beyond the window future = time.time() + 2.0 with patch("time.time", return_value=future): - assert limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True -def test_redis_fail_open_on_error() -> None: +@pytest.mark.asyncio() +async def test_redis_fail_open_on_error() -> None: """When Redis is unreachable, requests should be allowed (fail-open).""" fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=1) @@ -234,11 +244,12 @@ def test_redis_fail_open_on_error() -> None: limiter._client.pipeline = _broken_pipeline # type: ignore[assignment] # Should still allow (fail-open) even though Redis is broken - assert limiter.is_allowed("client-a") is True - assert limiter.is_allowed("client-a") is True # would normally be blocked + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True # would normally be blocked -def test_redis_fail_open_logs_warning() -> None: +@pytest.mark.asyncio() +async def test_redis_fail_open_logs_warning() -> None: """Verify a warning is logged when Redis is unreachable.""" fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=1) @@ -255,7 +266,7 @@ def test_redis_fail_open_logs_warning() -> None: limiter._client.pipeline = _broken_pipeline # type: ignore[assignment] with patch("app.core.rate_limit.logger") as mock_logger: - limiter.is_allowed("client-a") + await limiter.is_allowed("client-a") mock_logger.warning.assert_called_once() @@ -272,9 +283,11 @@ def test_factory_returns_memory_by_default(monkeypatch: pytest.MonkeyPatch) -> N def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS) - monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0") + monkeypatch.setattr( + "app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0" + ) fake = _FakeRedis() - with patch("redis.Redis.from_url", return_value=fake): + with patch("redis.asyncio.from_url", return_value=fake): limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0) assert isinstance(limiter, RedisRateLimiter)