refactor: switch RedisRateLimiter to async redis.asyncio client

Replace the synchronous redis.Redis client with redis.asyncio so that
rate-limit checks no longer block the event loop. Make
RateLimiter.is_allowed async in both backends (in-memory and Redis) and
update all call sites to await it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hugh Brown
2026-03-04 12:53:02 -07:00
committed by Abhimanyu Saharan
parent b4bbe1c657
commit 6b55b52a68
4 changed files with 74 additions and 60 deletions

View File

@@ -46,7 +46,7 @@ class _FakePipeline:
self._ops.append(("expire", key, str(seconds)))
return self
def execute(self) -> list[object]:
async def execute(self) -> list[object]:
results: list[object] = []
for op in self._ops:
cmd = op[0]
@@ -86,48 +86,53 @@ class _FakeRedis:
# ---------------------------------------------------------------------------
# InMemoryRateLimiter tests (unchanged from original)
# InMemoryRateLimiter tests
# ---------------------------------------------------------------------------
def test_allows_requests_within_limit() -> None:
@pytest.mark.asyncio()
async def test_allows_requests_within_limit() -> None:
limiter = InMemoryRateLimiter(max_requests=5, window_seconds=60.0)
for _ in range(5):
assert limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
def test_blocks_requests_over_limit() -> None:
@pytest.mark.asyncio()
async def test_blocks_requests_over_limit() -> None:
limiter = InMemoryRateLimiter(max_requests=3, window_seconds=60.0)
for _ in range(3):
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is False
assert limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is False
def test_separate_keys_have_independent_limits() -> None:
@pytest.mark.asyncio()
async def test_separate_keys_have_independent_limits() -> None:
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0)
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
# Different key still allowed
assert limiter.is_allowed("client-b") is True
assert limiter.is_allowed("client-b") is True
assert limiter.is_allowed("client-b") is False
assert await limiter.is_allowed("client-b") is True
assert await limiter.is_allowed("client-b") is True
assert await limiter.is_allowed("client-b") is False
def test_window_expiry_resets_limit() -> None:
@pytest.mark.asyncio()
async def test_window_expiry_resets_limit() -> None:
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=1.0)
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
# Simulate time passing beyond the window
future = time.monotonic() + 2.0
with patch("time.monotonic", return_value=future):
assert limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
def test_sweep_removes_expired_keys() -> None:
@pytest.mark.asyncio()
async def test_sweep_removes_expired_keys() -> None:
"""Keys whose timestamps have all expired should be evicted during periodic sweep."""
from app.core.rate_limit import _CLEANUP_INTERVAL
@@ -135,7 +140,7 @@ def test_sweep_removes_expired_keys() -> None:
# Fill with many unique IPs
for i in range(10):
limiter.is_allowed(f"stale-{i}")
await limiter.is_allowed(f"stale-{i}")
assert len(limiter._buckets) == 10
@@ -146,7 +151,7 @@ def test_sweep_removes_expired_keys() -> None:
# Drive the call count up to a multiple of _CLEANUP_INTERVAL
remaining = _CLEANUP_INTERVAL - (limiter._call_count % _CLEANUP_INTERVAL)
for i in range(remaining):
limiter.is_allowed("trigger-sweep")
await limiter.is_allowed("trigger-sweep")
# Stale keys should have been swept; only "trigger-sweep" should remain
assert "stale-0" not in limiter._buckets
@@ -166,7 +171,7 @@ def _make_redis_limiter(
window_seconds: float = 60.0,
) -> RedisRateLimiter:
"""Build a RedisRateLimiter wired to a _FakeRedis instance."""
with patch("redis.Redis.from_url", return_value=fake):
with patch("redis.asyncio.from_url", return_value=fake):
return RedisRateLimiter(
namespace=namespace,
max_requests=max_requests,
@@ -175,48 +180,53 @@ def _make_redis_limiter(
)
def test_redis_allows_within_limit() -> None:
@pytest.mark.asyncio()
async def test_redis_allows_within_limit() -> None:
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=5)
for _ in range(5):
assert limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
def test_redis_blocks_over_limit() -> None:
@pytest.mark.asyncio()
async def test_redis_blocks_over_limit() -> None:
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=3)
for _ in range(3):
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is False
assert limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is False
def test_redis_separate_keys_independent() -> None:
@pytest.mark.asyncio()
async def test_redis_separate_keys_independent() -> None:
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=2)
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
# Different key still allowed
assert limiter.is_allowed("client-b") is True
assert limiter.is_allowed("client-b") is True
assert limiter.is_allowed("client-b") is False
assert await limiter.is_allowed("client-b") is True
assert await limiter.is_allowed("client-b") is True
assert await limiter.is_allowed("client-b") is False
def test_redis_window_expiry() -> None:
@pytest.mark.asyncio()
async def test_redis_window_expiry() -> None:
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=2, window_seconds=1.0)
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
# Simulate time passing beyond the window
future = time.time() + 2.0
with patch("time.time", return_value=future):
assert limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
def test_redis_fail_open_on_error() -> None:
@pytest.mark.asyncio()
async def test_redis_fail_open_on_error() -> None:
"""When Redis is unreachable, requests should be allowed (fail-open)."""
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=1)
@@ -234,11 +244,12 @@ def test_redis_fail_open_on_error() -> None:
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
# Should still allow (fail-open) even though Redis is broken
assert limiter.is_allowed("client-a") is True
assert limiter.is_allowed("client-a") is True # would normally be blocked
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True # would normally be blocked
def test_redis_fail_open_logs_warning() -> None:
@pytest.mark.asyncio()
async def test_redis_fail_open_logs_warning() -> None:
"""Verify a warning is logged when Redis is unreachable."""
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=1)
@@ -255,7 +266,7 @@ def test_redis_fail_open_logs_warning() -> None:
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
with patch("app.core.rate_limit.logger") as mock_logger:
limiter.is_allowed("client-a")
await limiter.is_allowed("client-a")
mock_logger.warning.assert_called_once()
@@ -272,9 +283,11 @@ def test_factory_returns_memory_by_default(monkeypatch: pytest.MonkeyPatch) -> N
def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS)
monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0")
monkeypatch.setattr(
"app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0"
)
fake = _FakeRedis()
with patch("redis.Redis.from_url", return_value=fake):
with patch("redis.asyncio.from_url", return_value=fake):
limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
assert isinstance(limiter, RedisRateLimiter)