fix: share a single async Redis client per URL to avoid duplicate connection pools

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hugh Brown
2026-03-04 14:13:45 -07:00
committed by Abhimanyu Saharan
parent a30b94c887
commit e053fd4a46
2 changed files with 15 additions and 3 deletions

View File

@@ -24,6 +24,18 @@ logger = get_logger(__name__)
# Run a full sweep of all keys every 128 calls to is_allowed. # Run a full sweep of all keys every 128 calls to is_allowed.
_CLEANUP_INTERVAL = 128 _CLEANUP_INTERVAL = 128
# Shared async Redis clients keyed by URL to avoid duplicate connection pools.
_async_redis_clients: dict[str, aioredis.Redis] = {}
def _get_async_redis(redis_url: str) -> aioredis.Redis:
"""Return a shared async Redis client for *redis_url*, creating one if needed."""
client = _async_redis_clients.get(redis_url)
if client is None:
client = aioredis.from_url(redis_url)
_async_redis_clients[redis_url] = client
return client
class RateLimiter(ABC): class RateLimiter(ABC):
"""Base interface for sliding-window rate limiters.""" """Base interface for sliding-window rate limiters."""
@@ -96,7 +108,7 @@ class RedisRateLimiter(RateLimiter):
self._namespace = namespace self._namespace = namespace
self._max_requests = max_requests self._max_requests = max_requests
self._window_seconds = window_seconds self._window_seconds = window_seconds
self._client: aioredis.Redis = aioredis.from_url(redis_url) self._client: aioredis.Redis = _get_async_redis(redis_url)
async def is_allowed(self, key: str) -> bool: async def is_allowed(self, key: str) -> bool:
"""Return True if the request should be allowed, False if rate-limited.""" """Return True if the request should be allowed, False if rate-limited."""

View File

@@ -171,7 +171,7 @@ def _make_redis_limiter(
window_seconds: float = 60.0, window_seconds: float = 60.0,
) -> RedisRateLimiter: ) -> RedisRateLimiter:
"""Build a RedisRateLimiter wired to a _FakeRedis instance.""" """Build a RedisRateLimiter wired to a _FakeRedis instance."""
with patch("redis.asyncio.from_url", return_value=fake): with patch("app.core.rate_limit._get_async_redis", return_value=fake):
return RedisRateLimiter( return RedisRateLimiter(
namespace=namespace, namespace=namespace,
max_requests=max_requests, max_requests=max_requests,
@@ -285,7 +285,7 @@ def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch)
monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS) monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS)
monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0") monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0")
fake = _FakeRedis() fake = _FakeRedis()
with patch("redis.asyncio.from_url", return_value=fake): with patch("app.core.rate_limit._get_async_redis", return_value=fake):
limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0) limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
assert isinstance(limiter, RedisRateLimiter) assert isinstance(limiter, RedisRateLimiter)