diff --git a/backend/app/core/rate_limit.py b/backend/app/core/rate_limit.py index 216a9d23..6277cbf0 100644 --- a/backend/app/core/rate_limit.py +++ b/backend/app/core/rate_limit.py @@ -24,6 +24,18 @@ logger = get_logger(__name__) # Run a full sweep of all keys every 128 calls to is_allowed. _CLEANUP_INTERVAL = 128 +# Shared async Redis clients keyed by URL to avoid duplicate connection pools. +_async_redis_clients: dict[str, aioredis.Redis] = {} + + +def _get_async_redis(redis_url: str) -> aioredis.Redis: + """Return a shared async Redis client for *redis_url*, creating one if needed.""" + client = _async_redis_clients.get(redis_url) + if client is None: + client = aioredis.from_url(redis_url) + _async_redis_clients[redis_url] = client + return client + class RateLimiter(ABC): """Base interface for sliding-window rate limiters.""" @@ -96,7 +108,7 @@ class RedisRateLimiter(RateLimiter): self._namespace = namespace self._max_requests = max_requests self._window_seconds = window_seconds - self._client: aioredis.Redis = aioredis.from_url(redis_url) + self._client: aioredis.Redis = _get_async_redis(redis_url) async def is_allowed(self, key: str) -> bool: """Return True if the request should be allowed, False if rate-limited.""" diff --git a/backend/tests/test_rate_limit.py b/backend/tests/test_rate_limit.py index beec1ad3..5741ee12 100644 --- a/backend/tests/test_rate_limit.py +++ b/backend/tests/test_rate_limit.py @@ -171,7 +171,7 @@ def _make_redis_limiter( window_seconds: float = 60.0, ) -> RedisRateLimiter: """Build a RedisRateLimiter wired to a _FakeRedis instance.""" - with patch("redis.asyncio.from_url", return_value=fake): + with patch("app.core.rate_limit._get_async_redis", return_value=fake): return RedisRateLimiter( namespace=namespace, max_requests=max_requests, @@ -285,7 +285,7 @@ def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS) monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0") fake = _FakeRedis() - with patch("redis.asyncio.from_url", return_value=fake): + with patch("app.core.rate_limit._get_async_redis", return_value=fake): limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0) assert isinstance(limiter, RedisRateLimiter)