feat: add Redis-backed rate limiter with configurable backend
Add RedisRateLimiter using sorted-set sliding window alongside the existing InMemoryRateLimiter. Users choose via RATE_LIMIT_BACKEND (memory|redis) with RATE_LIMIT_REDIS_URL falling back to RQ_REDIS_URL. Redis backend validates connectivity at startup and fails open on transient errors during requests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
Abhimanyu Saharan
parent
ee825fb2d5
commit
fc9fc1661c
@@ -10,6 +10,7 @@ from pydantic import Field, model_validator
|
|||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from app.core.auth_mode import AuthMode
|
from app.core.auth_mode import AuthMode
|
||||||
|
from app.core.rate_limit_backend import RateLimitBackend
|
||||||
|
|
||||||
BACKEND_ROOT = Path(__file__).resolve().parents[2]
|
BACKEND_ROOT = Path(__file__).resolve().parents[2]
|
||||||
DEFAULT_ENV_FILE = BACKEND_ROOT / ".env"
|
DEFAULT_ENV_FILE = BACKEND_ROOT / ".env"
|
||||||
@@ -60,6 +61,10 @@ class Settings(BaseSettings):
|
|||||||
# Webhook payload size limit in bytes (default 1 MB).
|
# Webhook payload size limit in bytes (default 1 MB).
|
||||||
webhook_max_payload_bytes: int = 1_048_576
|
webhook_max_payload_bytes: int = 1_048_576
|
||||||
|
|
||||||
|
# Rate limiting
|
||||||
|
rate_limit_backend: RateLimitBackend = RateLimitBackend.MEMORY
|
||||||
|
rate_limit_redis_url: str = ""
|
||||||
|
|
||||||
# Database lifecycle
|
# Database lifecycle
|
||||||
db_auto_migrate: bool = False
|
db_auto_migrate: bool = False
|
||||||
|
|
||||||
@@ -98,6 +103,7 @@ class Settings(BaseSettings):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"LOCAL_AUTH_TOKEN must be at least 50 characters and non-placeholder when AUTH_MODE=local.",
|
"LOCAL_AUTH_TOKEN must be at least 50 characters and non-placeholder when AUTH_MODE=local.",
|
||||||
)
|
)
|
||||||
|
|
||||||
base_url = self.base_url.strip()
|
base_url = self.base_url.strip()
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise ValueError("BASE_URL must be set and non-empty.")
|
raise ValueError("BASE_URL must be set and non-empty.")
|
||||||
@@ -107,6 +113,15 @@ class Settings(BaseSettings):
|
|||||||
"BASE_URL must be an absolute http(s) URL (e.g. http://localhost:8000).",
|
"BASE_URL must be an absolute http(s) URL (e.g. http://localhost:8000).",
|
||||||
)
|
)
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
|
|
||||||
|
# Rate-limit: fall back to rq_redis_url if using redis backend
|
||||||
|
# with no explicit rate-limit URL.
|
||||||
|
if (
|
||||||
|
self.rate_limit_backend == RateLimitBackend.REDIS
|
||||||
|
and not self.rate_limit_redis_url.strip()
|
||||||
|
):
|
||||||
|
self.rate_limit_redis_url = self.rq_redis_url
|
||||||
|
|
||||||
# In dev, default to applying Alembic migrations at startup to avoid
|
# In dev, default to applying Alembic migrations at startup to avoid
|
||||||
# schema drift (e.g. missing newly-added columns).
|
# schema drift (e.g. missing newly-added columns).
|
||||||
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":
|
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":
|
||||||
|
|||||||
@@ -1,25 +1,38 @@
|
|||||||
"""Simple in-memory sliding-window rate limiter for abuse prevention.
|
"""Sliding-window rate limiters for abuse prevention.
|
||||||
|
|
||||||
This provides per-IP rate limiting without external dependencies.
|
Supports an in-memory backend (default, no external dependencies) and
|
||||||
Each key maintains a sliding window of recent request timestamps;
|
a Redis-backed backend for multi-process / distributed deployments.
|
||||||
a request is allowed only when the number of timestamps within the
|
Configure via RATE_LIMIT_BACKEND=memory|redis.
|
||||||
window is below the configured maximum.
|
|
||||||
|
|
||||||
For multi-process or distributed deployments, a Redis-based limiter
|
|
||||||
should be used instead.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
|
import redis as redis_lib
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.core.rate_limit_backend import RateLimitBackend
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# Run a full sweep of all keys every 128 calls to is_allowed.
|
# Run a full sweep of all keys every 128 calls to is_allowed.
|
||||||
_CLEANUP_INTERVAL = 128
|
_CLEANUP_INTERVAL = 128
|
||||||
|
|
||||||
|
|
||||||
class InMemoryRateLimiter:
|
class RateLimiter(ABC):
    """Common contract implemented by every sliding-window rate limiter backend."""

    @abstractmethod
    def is_allowed(self, key: str) -> bool:
        """Decide whether the request identified by *key* may proceed.

        Returns True when the request fits within the current window,
        False when it must be rejected as rate-limited.
        """
|
class InMemoryRateLimiter(RateLimiter):
|
||||||
"""Sliding-window rate limiter keyed by arbitrary string (typically client IP)."""
|
"""Sliding-window rate limiter keyed by arbitrary string (typically client IP)."""
|
||||||
|
|
||||||
def __init__(self, *, max_requests: int, window_seconds: float) -> None:
|
def __init__(self, *, max_requests: int, window_seconds: float) -> None:
|
||||||
@@ -61,8 +74,103 @@ class InMemoryRateLimiter:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class RedisRateLimiter(RateLimiter):
    """Redis-backed sliding-window rate limiter built on sorted sets.

    Every tracked key maps to a Redis sorted set: members are unique
    per-request identifiers and scores are wall-clock timestamps.  One
    transactional pipeline prunes entries older than the window, records
    the current request, counts what remains, and refreshes the key's
    TTL — a single round-trip per check.

    Note that a request is recorded *before* the count is compared, so
    rejected requests also land in the set and extend the caller's
    lockout while it keeps retrying.

    Fail-open: if Redis is unreachable while handling a request, the
    request is allowed and a warning is logged.
    """

    def __init__(
        self,
        *,
        namespace: str,
        max_requests: int,
        window_seconds: float,
        redis_url: str,
    ) -> None:
        self._namespace = namespace
        self._max_requests = max_requests
        self._window_seconds = window_seconds
        # Client is created eagerly; reachability is checked separately at
        # startup via validate_rate_limit_redis().
        self._client: redis_lib.Redis = redis_lib.Redis.from_url(redis_url)

    def is_allowed(self, key: str) -> bool:
        """Return True if the request should be allowed, False if rate-limited."""
        bucket = f"ratelimit:{self._namespace}:{key}"
        now = time.time()
        # Entries scored at or before this instant fall outside the window.
        oldest_allowed = now - self._window_seconds
        # Unique member so two requests in the same clock tick both count.
        request_id = f"{now}:{uuid.uuid4().hex[:8]}"

        try:
            batch = self._client.pipeline(transaction=True)
            batch.zremrangebyscore(bucket, "-inf", oldest_allowed)
            batch.zadd(bucket, {request_id: now})
            batch.zcard(bucket)
            # TTL slightly past the window so idle keys expire on their own.
            batch.expire(bucket, int(self._window_seconds) + 1)
            replies = batch.execute()
            in_window: int = replies[2]
        except Exception:
            logger.warning(
                "rate_limit.redis.unavailable namespace=%s key=%s",
                self._namespace,
                key,
                exc_info=True,
            )
            # Fail open rather than reject traffic on a transient outage.
            return True

        return in_window <= self._max_requests
|
def validate_rate_limit_redis(redis_url: str) -> None:
    """Check that the rate-limit Redis answers a PING.

    Intended as a startup gate: call once before serving traffic.

    Raises:
        ConnectionError: when the server at *redis_url* cannot be reached.
    """
    probe = redis_lib.Redis.from_url(redis_url)
    try:
        try:
            probe.ping()
        except Exception as exc:
            raise ConnectionError(
                f"Redis rate-limit backend configured but unreachable at {redis_url}: {exc}",
            ) from exc
    finally:
        # Always release the probe connection, success or failure.
        probe.close()
|
def create_rate_limiter(
    *,
    namespace: str,
    max_requests: int,
    window_seconds: float,
) -> RateLimiter:
    """Instantiate the limiter selected by ``settings.rate_limit_backend``."""
    # Deferred import: settings is only needed at construction time.
    from app.core.config import settings

    if settings.rate_limit_backend != RateLimitBackend.REDIS:
        # Default path: per-process in-memory sliding window.
        return InMemoryRateLimiter(
            max_requests=max_requests,
            window_seconds=window_seconds,
        )
    return RedisRateLimiter(
        namespace=namespace,
        max_requests=max_requests,
        window_seconds=window_seconds,
        redis_url=settings.rate_limit_redis_url,
    )
||||||
# Module-level limiter singletons shared by the endpoint handlers.
# Agent auth: 20 attempts per 60 seconds per IP.
agent_auth_limiter: RateLimiter = create_rate_limiter(
    namespace="agent_auth", max_requests=20, window_seconds=60.0
)
# Webhook ingest: 60 requests per 60 seconds per IP.
webhook_ingest_limiter: RateLimiter = create_rate_limiter(
    namespace="webhook_ingest", max_requests=60, window_seconds=60.0
)
|
|||||||
12
backend/app/core/rate_limit_backend.py
Normal file
12
backend/app/core/rate_limit_backend.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Rate-limit backend selection enum."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitBackend(str, Enum):
    """Supported rate-limiting backends.

    Mixes in ``str`` so each member compares equal to its plain string
    value (e.g. ``RateLimitBackend.MEMORY == "memory"``).
    """

    # Per-process sliding window; no external dependencies.
    MEMORY = "memory"
    # Shared sliding window stored in Redis.
    REDIS = "redis"
||||||
@@ -34,6 +34,8 @@ from app.api.users import router as users_router
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.error_handling import install_error_handling
|
from app.core.error_handling import install_error_handling
|
||||||
from app.core.logging import configure_logging, get_logger
|
from app.core.logging import configure_logging, get_logger
|
||||||
|
from app.core.rate_limit import validate_rate_limit_redis
|
||||||
|
from app.core.rate_limit_backend import RateLimitBackend
|
||||||
from app.core.security_headers import SecurityHeadersMiddleware
|
from app.core.security_headers import SecurityHeadersMiddleware
|
||||||
from app.db.session import init_db
|
from app.db.session import init_db
|
||||||
from app.schemas.health import HealthStatusResponse
|
from app.schemas.health import HealthStatusResponse
|
||||||
@@ -437,6 +439,11 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
|||||||
settings.db_auto_migrate,
|
settings.db_auto_migrate,
|
||||||
)
|
)
|
||||||
await init_db()
|
await init_db()
|
||||||
|
if settings.rate_limit_backend == RateLimitBackend.REDIS:
|
||||||
|
validate_rate_limit_redis(settings.rate_limit_redis_url)
|
||||||
|
logger.info("app.lifecycle.rate_limit backend=redis")
|
||||||
|
else:
|
||||||
|
logger.info("app.lifecycle.rate_limit backend=memory")
|
||||||
logger.info("app.lifecycle.started")
|
logger.info("app.lifecycle.started")
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -1,11 +1,93 @@
|
|||||||
"""Tests for the in-memory rate limiter."""
|
"""Tests for rate limiters (in-memory and Redis-backed)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from app.core.rate_limit import InMemoryRateLimiter
|
import pytest
|
||||||
|
|
||||||
|
from app.core.rate_limit import (
|
||||||
|
InMemoryRateLimiter,
|
||||||
|
RedisRateLimiter,
|
||||||
|
create_rate_limiter,
|
||||||
|
validate_rate_limit_redis,
|
||||||
|
)
|
||||||
|
from app.core.rate_limit_backend import RateLimitBackend
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fake Redis helpers for deterministic testing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakePipeline:
|
||||||
|
"""Minimal sorted-set pipeline that executes against a _FakeRedis."""
|
||||||
|
|
||||||
|
def __init__(self, parent: _FakeRedis) -> None:
|
||||||
|
self._parent = parent
|
||||||
|
self._ops: list[tuple[str, ...]] = []
|
||||||
|
|
||||||
|
# Pipeline command stubs -- each just records intent and returns self
|
||||||
|
# so chaining works (even though our tests don't chain).
|
||||||
|
|
||||||
|
def zremrangebyscore(self, key: str, min_val: str, max_val: float) -> _FakePipeline:
|
||||||
|
self._ops.append(("zremrangebyscore", key, str(min_val), str(max_val)))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def zadd(self, key: str, mapping: dict[str, float]) -> _FakePipeline:
|
||||||
|
self._ops.append(("zadd", key, *next(iter(mapping.items()))))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def zcard(self, key: str) -> _FakePipeline:
|
||||||
|
self._ops.append(("zcard", key))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def expire(self, key: str, seconds: int) -> _FakePipeline:
|
||||||
|
self._ops.append(("expire", key, str(seconds)))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def execute(self) -> list[object]:
|
||||||
|
results: list[object] = []
|
||||||
|
for op in self._ops:
|
||||||
|
cmd = op[0]
|
||||||
|
key = op[1]
|
||||||
|
zset = self._parent._sorted_sets.setdefault(key, {})
|
||||||
|
if cmd == "zremrangebyscore":
|
||||||
|
max_score = float(op[3])
|
||||||
|
expired = [m for m, s in zset.items() if s <= max_score]
|
||||||
|
for m in expired:
|
||||||
|
del zset[m]
|
||||||
|
results.append(len(expired))
|
||||||
|
elif cmd == "zadd":
|
||||||
|
member, score = op[2], float(op[3])
|
||||||
|
zset[member] = score
|
||||||
|
results.append(1)
|
||||||
|
elif cmd == "zcard":
|
||||||
|
results.append(len(zset))
|
||||||
|
elif cmd == "expire":
|
||||||
|
results.append(True)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRedis:
|
||||||
|
"""Minimal in-process Redis fake supporting sorted-set pipeline ops."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._sorted_sets: dict[str, dict[str, float]] = {}
|
||||||
|
|
||||||
|
def pipeline(self, *, transaction: bool = True) -> _FakePipeline:
|
||||||
|
return _FakePipeline(self)
|
||||||
|
|
||||||
|
def ping(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# InMemoryRateLimiter tests (unchanged from original)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def test_allows_requests_within_limit() -> None:
|
def test_allows_requests_within_limit() -> None:
|
||||||
@@ -69,3 +151,150 @@ def test_sweep_removes_expired_keys() -> None:
|
|||||||
# Stale keys should have been swept; only "trigger-sweep" should remain
|
# Stale keys should have been swept; only "trigger-sweep" should remain
|
||||||
assert "stale-0" not in limiter._buckets
|
assert "stale-0" not in limiter._buckets
|
||||||
assert "trigger-sweep" in limiter._buckets
|
assert "trigger-sweep" in limiter._buckets
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RedisRateLimiter tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_redis_limiter(
    fake: _FakeRedis,
    *,
    namespace: str = "test",
    max_requests: int = 5,
    window_seconds: float = 60.0,
) -> RedisRateLimiter:
    """Build a RedisRateLimiter whose client is the supplied _FakeRedis."""
    with patch("redis.Redis.from_url", return_value=fake):
        limiter = RedisRateLimiter(
            namespace=namespace,
            max_requests=max_requests,
            window_seconds=window_seconds,
            redis_url="redis://fake:6379/0",
        )
    return limiter
|
|
||||||
|
def test_redis_allows_within_limit() -> None:
    """Every request up to the cap passes through."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=5)
    assert all(limiter.is_allowed("client-a") is True for _ in range(5))
|
|
||||||
|
def test_redis_blocks_over_limit() -> None:
    """Requests beyond the cap are rejected, and stay rejected."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=3)
    assert all(limiter.is_allowed("client-a") is True for _ in range(3))
    assert limiter.is_allowed("client-a") is False
    assert limiter.is_allowed("client-a") is False
|
|
||||||
|
def test_redis_separate_keys_independent() -> None:
    """Exhausting one key's budget leaves other keys unaffected."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=2)
    for client in ("client-a", "client-b"):
        assert limiter.is_allowed(client) is True
        assert limiter.is_allowed(client) is True
        assert limiter.is_allowed(client) is False
|
|
||||||
|
def test_redis_window_expiry() -> None:
    """Entries older than the window stop counting against the cap."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=2, window_seconds=1.0)
    assert limiter.is_allowed("client-a") is True
    assert limiter.is_allowed("client-a") is True
    assert limiter.is_allowed("client-a") is False

    # Jump the clock well past the 1-second window; the earlier entries
    # must be pruned and the key allowed again.
    future = time.time() + 2.0
    with patch("time.time", return_value=future):
        assert limiter.is_allowed("client-a") is True
||||||
|
|
||||||
|
def test_redis_fail_open_on_error() -> None:
    """When Redis is unreachable, requests should be allowed (fail-open)."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=1)

    def _broken_pipeline(*, transaction: bool = True) -> MagicMock:
        # Each command chains back to the pipe; execute() blows up.
        pipe = MagicMock()
        for cmd in ("zremrangebyscore", "zadd", "zcard", "expire"):
            getattr(pipe, cmd).return_value = pipe
        pipe.execute.side_effect = ConnectionError("Redis gone")
        return pipe

    limiter._client.pipeline = _broken_pipeline  # type: ignore[assignment]

    # Both calls succeed even though the second would normally be blocked.
    assert limiter.is_allowed("client-a") is True
    assert limiter.is_allowed("client-a") is True
|
|
||||||
|
def test_redis_fail_open_logs_warning() -> None:
    """Verify a warning is logged when Redis is unreachable."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=1)

    def _broken_pipeline(*, transaction: bool = True) -> MagicMock:
        # Each command chains back to the pipe; execute() blows up.
        pipe = MagicMock()
        for cmd in ("zremrangebyscore", "zadd", "zcard", "expire"):
            getattr(pipe, cmd).return_value = pipe
        pipe.execute.side_effect = ConnectionError("Redis gone")
        return pipe

    limiter._client.pipeline = _broken_pipeline  # type: ignore[assignment]

    with patch("app.core.rate_limit.logger") as mock_logger:
        limiter.is_allowed("client-a")
        mock_logger.warning.assert_called_once()
|
# ---------------------------------------------------------------------------
|
||||||
|
# Factory tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_returns_memory_by_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """The MEMORY backend setting yields the in-memory implementation."""
    monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.MEMORY)
    built = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
    assert isinstance(built, InMemoryRateLimiter)
||||||
|
|
||||||
|
def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch) -> None:
    """The REDIS backend setting yields the Redis-backed implementation."""
    monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS)
    monkeypatch.setattr(
        "app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0"
    )
    with patch("redis.Redis.from_url", return_value=_FakeRedis()):
        built = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
    assert isinstance(built, RedisRateLimiter)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Startup validation tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_redis_succeeds_when_reachable() -> None:
    """A reachable Redis passes startup validation without raising."""
    with patch("redis.Redis.from_url", return_value=_FakeRedis()):
        validate_rate_limit_redis("redis://localhost:6379/0")
||||||
|
|
||||||
|
def test_validate_redis_raises_on_unreachable() -> None:
    """A failed PING is surfaced as ConnectionError at startup."""
    broken = MagicMock()
    broken.ping.side_effect = ConnectionError("refused")
    with patch("redis.Redis.from_url", return_value=broken):
        with pytest.raises(ConnectionError, match="unreachable"):
            validate_rate_limit_redis("redis://bad:6379/0")
|
|||||||
Reference in New Issue
Block a user