refactor: switch RedisRateLimiter to async redis.asyncio client
Replace sync redis.Redis with redis.asyncio to avoid blocking the event loop during rate-limit checks. Make RateLimiter.is_allowed async across both backends and update all call sites to await it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
Abhimanyu Saharan
parent
b4bbe1c657
commit
6b55b52a68
@@ -509,7 +509,7 @@ async def ingest_board_webhook(
|
||||
) -> BoardWebhookIngestResponse:
|
||||
"""Open inbound webhook endpoint that stores payloads and nudges the board lead."""
|
||||
client_ip = get_client_ip(request)
|
||||
if not webhook_ingest_limiter.is_allowed(client_ip):
|
||||
if not await webhook_ingest_limiter.is_allowed(client_ip):
|
||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
||||
webhook = await _require_board_webhook(
|
||||
session,
|
||||
|
||||
@@ -115,7 +115,7 @@ async def get_agent_auth_context(
|
||||
) -> AgentAuthContext:
|
||||
"""Require and validate agent auth token from request headers."""
|
||||
client_ip = get_client_ip(request)
|
||||
if not agent_auth_limiter.is_allowed(client_ip):
|
||||
if not await agent_auth_limiter.is_allowed(client_ip):
|
||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
||||
resolved = _resolve_agent_token(
|
||||
agent_token,
|
||||
@@ -176,7 +176,7 @@ async def get_agent_auth_context_optional(
|
||||
# normal user Authorization headers are not throttled.
|
||||
if agent_token:
|
||||
client_ip = get_client_ip(request)
|
||||
if not agent_auth_limiter.is_allowed(client_ip):
|
||||
if not await agent_auth_limiter.is_allowed(client_ip):
|
||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
||||
agent = await _find_agent_for_token(session, resolved)
|
||||
if agent is None:
|
||||
|
||||
@@ -14,6 +14,7 @@ from collections import deque
|
||||
from threading import Lock
|
||||
|
||||
import redis as redis_lib
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.rate_limit_backend import RateLimitBackend
|
||||
@@ -28,7 +29,7 @@ class RateLimiter(ABC):
|
||||
"""Base interface for sliding-window rate limiters."""
|
||||
|
||||
@abstractmethod
|
||||
def is_allowed(self, key: str) -> bool:
|
||||
async def is_allowed(self, key: str) -> bool:
|
||||
"""Return True if the request should be allowed, False if rate-limited."""
|
||||
|
||||
|
||||
@@ -50,7 +51,7 @@ class InMemoryRateLimiter(RateLimiter):
|
||||
for k in expired_keys:
|
||||
del self._buckets[k]
|
||||
|
||||
def is_allowed(self, key: str) -> bool:
|
||||
async def is_allowed(self, key: str) -> bool:
|
||||
"""Return True if the request should be allowed, False if rate-limited."""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self._window_seconds
|
||||
@@ -78,9 +79,9 @@ class RedisRateLimiter(RateLimiter):
|
||||
"""Redis-backed sliding-window rate limiter using sorted sets.
|
||||
|
||||
Each key is stored as a Redis sorted set where members are unique
|
||||
request identifiers and scores are wall-clock timestamps. A pipeline
|
||||
prunes expired entries, adds the new request, counts the window, and
|
||||
sets a TTL — all in a single round-trip.
|
||||
request identifiers and scores are wall-clock timestamps. An async
|
||||
pipeline prunes expired entries, adds the new request, counts the
|
||||
window, and sets a TTL — all in a single round-trip.
|
||||
|
||||
Fail-open: if Redis is unreachable during a request, the request is
|
||||
allowed and a warning is logged.
|
||||
@@ -97,9 +98,9 @@ class RedisRateLimiter(RateLimiter):
|
||||
self._namespace = namespace
|
||||
self._max_requests = max_requests
|
||||
self._window_seconds = window_seconds
|
||||
self._client: redis_lib.Redis = redis_lib.Redis.from_url(redis_url)
|
||||
self._client: aioredis.Redis = aioredis.from_url(redis_url)
|
||||
|
||||
def is_allowed(self, key: str) -> bool:
|
||||
async def is_allowed(self, key: str) -> bool:
|
||||
"""Return True if the request should be allowed, False if rate-limited."""
|
||||
redis_key = f"ratelimit:{self._namespace}:{key}"
|
||||
now = time.time()
|
||||
@@ -112,7 +113,7 @@ class RedisRateLimiter(RateLimiter):
|
||||
pipe.zadd(redis_key, {member: now})
|
||||
pipe.zcard(redis_key)
|
||||
pipe.expire(redis_key, int(self._window_seconds) + 1)
|
||||
results = pipe.execute()
|
||||
results = await pipe.execute()
|
||||
count: int = results[2]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
|
||||
@@ -46,7 +46,7 @@ class _FakePipeline:
|
||||
self._ops.append(("expire", key, str(seconds)))
|
||||
return self
|
||||
|
||||
def execute(self) -> list[object]:
|
||||
async def execute(self) -> list[object]:
|
||||
results: list[object] = []
|
||||
for op in self._ops:
|
||||
cmd = op[0]
|
||||
@@ -86,48 +86,53 @@ class _FakeRedis:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InMemoryRateLimiter tests (unchanged from original)
|
||||
# InMemoryRateLimiter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_allows_requests_within_limit() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_allows_requests_within_limit() -> None:
|
||||
limiter = InMemoryRateLimiter(max_requests=5, window_seconds=60.0)
|
||||
for _ in range(5):
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
|
||||
|
||||
def test_blocks_requests_over_limit() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_blocks_requests_over_limit() -> None:
|
||||
limiter = InMemoryRateLimiter(max_requests=3, window_seconds=60.0)
|
||||
for _ in range(3):
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
|
||||
|
||||
def test_separate_keys_have_independent_limits() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_separate_keys_have_independent_limits() -> None:
|
||||
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0)
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
# Different key still allowed
|
||||
assert limiter.is_allowed("client-b") is True
|
||||
assert limiter.is_allowed("client-b") is True
|
||||
assert limiter.is_allowed("client-b") is False
|
||||
assert await limiter.is_allowed("client-b") is True
|
||||
assert await limiter.is_allowed("client-b") is True
|
||||
assert await limiter.is_allowed("client-b") is False
|
||||
|
||||
|
||||
def test_window_expiry_resets_limit() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_window_expiry_resets_limit() -> None:
|
||||
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=1.0)
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
|
||||
# Simulate time passing beyond the window
|
||||
future = time.monotonic() + 2.0
|
||||
with patch("time.monotonic", return_value=future):
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
|
||||
|
||||
def test_sweep_removes_expired_keys() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_sweep_removes_expired_keys() -> None:
|
||||
"""Keys whose timestamps have all expired should be evicted during periodic sweep."""
|
||||
from app.core.rate_limit import _CLEANUP_INTERVAL
|
||||
|
||||
@@ -135,7 +140,7 @@ def test_sweep_removes_expired_keys() -> None:
|
||||
|
||||
# Fill with many unique IPs
|
||||
for i in range(10):
|
||||
limiter.is_allowed(f"stale-{i}")
|
||||
await limiter.is_allowed(f"stale-{i}")
|
||||
|
||||
assert len(limiter._buckets) == 10
|
||||
|
||||
@@ -146,7 +151,7 @@ def test_sweep_removes_expired_keys() -> None:
|
||||
# Drive the call count up to a multiple of _CLEANUP_INTERVAL
|
||||
remaining = _CLEANUP_INTERVAL - (limiter._call_count % _CLEANUP_INTERVAL)
|
||||
for i in range(remaining):
|
||||
limiter.is_allowed("trigger-sweep")
|
||||
await limiter.is_allowed("trigger-sweep")
|
||||
|
||||
# Stale keys should have been swept; only "trigger-sweep" should remain
|
||||
assert "stale-0" not in limiter._buckets
|
||||
@@ -166,7 +171,7 @@ def _make_redis_limiter(
|
||||
window_seconds: float = 60.0,
|
||||
) -> RedisRateLimiter:
|
||||
"""Build a RedisRateLimiter wired to a _FakeRedis instance."""
|
||||
with patch("redis.Redis.from_url", return_value=fake):
|
||||
with patch("redis.asyncio.from_url", return_value=fake):
|
||||
return RedisRateLimiter(
|
||||
namespace=namespace,
|
||||
max_requests=max_requests,
|
||||
@@ -175,48 +180,53 @@ def _make_redis_limiter(
|
||||
)
|
||||
|
||||
|
||||
def test_redis_allows_within_limit() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_redis_allows_within_limit() -> None:
|
||||
fake = _FakeRedis()
|
||||
limiter = _make_redis_limiter(fake, max_requests=5)
|
||||
for _ in range(5):
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
|
||||
|
||||
def test_redis_blocks_over_limit() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_redis_blocks_over_limit() -> None:
|
||||
fake = _FakeRedis()
|
||||
limiter = _make_redis_limiter(fake, max_requests=3)
|
||||
for _ in range(3):
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
|
||||
|
||||
def test_redis_separate_keys_independent() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_redis_separate_keys_independent() -> None:
|
||||
fake = _FakeRedis()
|
||||
limiter = _make_redis_limiter(fake, max_requests=2)
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
# Different key still allowed
|
||||
assert limiter.is_allowed("client-b") is True
|
||||
assert limiter.is_allowed("client-b") is True
|
||||
assert limiter.is_allowed("client-b") is False
|
||||
assert await limiter.is_allowed("client-b") is True
|
||||
assert await limiter.is_allowed("client-b") is True
|
||||
assert await limiter.is_allowed("client-b") is False
|
||||
|
||||
|
||||
def test_redis_window_expiry() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_redis_window_expiry() -> None:
|
||||
fake = _FakeRedis()
|
||||
limiter = _make_redis_limiter(fake, max_requests=2, window_seconds=1.0)
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is False
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is False
|
||||
|
||||
# Simulate time passing beyond the window
|
||||
future = time.time() + 2.0
|
||||
with patch("time.time", return_value=future):
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
|
||||
|
||||
def test_redis_fail_open_on_error() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_redis_fail_open_on_error() -> None:
|
||||
"""When Redis is unreachable, requests should be allowed (fail-open)."""
|
||||
fake = _FakeRedis()
|
||||
limiter = _make_redis_limiter(fake, max_requests=1)
|
||||
@@ -234,11 +244,12 @@ def test_redis_fail_open_on_error() -> None:
|
||||
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
|
||||
|
||||
# Should still allow (fail-open) even though Redis is broken
|
||||
assert limiter.is_allowed("client-a") is True
|
||||
assert limiter.is_allowed("client-a") is True # would normally be blocked
|
||||
assert await limiter.is_allowed("client-a") is True
|
||||
assert await limiter.is_allowed("client-a") is True # would normally be blocked
|
||||
|
||||
|
||||
def test_redis_fail_open_logs_warning() -> None:
|
||||
@pytest.mark.asyncio()
|
||||
async def test_redis_fail_open_logs_warning() -> None:
|
||||
"""Verify a warning is logged when Redis is unreachable."""
|
||||
fake = _FakeRedis()
|
||||
limiter = _make_redis_limiter(fake, max_requests=1)
|
||||
@@ -255,7 +266,7 @@ def test_redis_fail_open_logs_warning() -> None:
|
||||
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
|
||||
|
||||
with patch("app.core.rate_limit.logger") as mock_logger:
|
||||
limiter.is_allowed("client-a")
|
||||
await limiter.is_allowed("client-a")
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
|
||||
@@ -272,9 +283,11 @@ def test_factory_returns_memory_by_default(monkeypatch: pytest.MonkeyPatch) -> N
|
||||
|
||||
def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS)
|
||||
monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0")
|
||||
monkeypatch.setattr(
|
||||
"app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0"
|
||||
)
|
||||
fake = _FakeRedis()
|
||||
with patch("redis.Redis.from_url", return_value=fake):
|
||||
with patch("redis.asyncio.from_url", return_value=fake):
|
||||
limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
|
||||
assert isinstance(limiter, RedisRateLimiter)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user