fix: share a single async Redis client per URL to avoid duplicate connection pools
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
Abhimanyu Saharan
parent
a30b94c887
commit
e053fd4a46
@@ -24,6 +24,18 @@ logger = get_logger(__name__)
|
|||||||
# Run a full sweep of all keys every 128 calls to is_allowed.
|
# Run a full sweep of all keys every 128 calls to is_allowed.
|
||||||
_CLEANUP_INTERVAL = 128
|
_CLEANUP_INTERVAL = 128
|
||||||
|
|
||||||
|
# Shared async Redis clients keyed by URL to avoid duplicate connection pools.
|
||||||
|
_async_redis_clients: dict[str, aioredis.Redis] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_async_redis(redis_url: str) -> aioredis.Redis:
|
||||||
|
"""Return a shared async Redis client for *redis_url*, creating one if needed."""
|
||||||
|
client = _async_redis_clients.get(redis_url)
|
||||||
|
if client is None:
|
||||||
|
client = aioredis.from_url(redis_url)
|
||||||
|
_async_redis_clients[redis_url] = client
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter(ABC):
|
class RateLimiter(ABC):
|
||||||
"""Base interface for sliding-window rate limiters."""
|
"""Base interface for sliding-window rate limiters."""
|
||||||
@@ -96,7 +108,7 @@ class RedisRateLimiter(RateLimiter):
|
|||||||
self._namespace = namespace
|
self._namespace = namespace
|
||||||
self._max_requests = max_requests
|
self._max_requests = max_requests
|
||||||
self._window_seconds = window_seconds
|
self._window_seconds = window_seconds
|
||||||
self._client: aioredis.Redis = aioredis.from_url(redis_url)
|
self._client: aioredis.Redis = _get_async_redis(redis_url)
|
||||||
|
|
||||||
async def is_allowed(self, key: str) -> bool:
|
async def is_allowed(self, key: str) -> bool:
|
||||||
"""Return True if the request should be allowed, False if rate-limited."""
|
"""Return True if the request should be allowed, False if rate-limited."""
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ def _make_redis_limiter(
|
|||||||
window_seconds: float = 60.0,
|
window_seconds: float = 60.0,
|
||||||
) -> RedisRateLimiter:
|
) -> RedisRateLimiter:
|
||||||
"""Build a RedisRateLimiter wired to a _FakeRedis instance."""
|
"""Build a RedisRateLimiter wired to a _FakeRedis instance."""
|
||||||
with patch("redis.asyncio.from_url", return_value=fake):
|
with patch("app.core.rate_limit._get_async_redis", return_value=fake):
|
||||||
return RedisRateLimiter(
|
return RedisRateLimiter(
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
max_requests=max_requests,
|
max_requests=max_requests,
|
||||||
@@ -285,7 +285,7 @@ def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch)
|
|||||||
monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS)
|
monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS)
|
||||||
monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0")
|
monkeypatch.setattr("app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0")
|
||||||
fake = _FakeRedis()
|
fake = _FakeRedis()
|
||||||
with patch("redis.asyncio.from_url", return_value=fake):
|
with patch("app.core.rate_limit._get_async_redis", return_value=fake):
|
||||||
limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
|
limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
|
||||||
assert isinstance(limiter, RedisRateLimiter)
|
assert isinstance(limiter, RedisRateLimiter)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user