feat: add Redis-backed rate limiter with configurable backend

Add RedisRateLimiter using sorted-set sliding window alongside the
existing InMemoryRateLimiter. Users choose via RATE_LIMIT_BACKEND
(memory|redis) with RATE_LIMIT_REDIS_URL falling back to RQ_REDIS_URL.
Redis backend validates connectivity at startup and fails open on
transient errors during requests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hugh Brown
2026-03-04 02:15:14 -07:00
committed by Abhimanyu Saharan
parent ee825fb2d5
commit fc9fc1661c
5 changed files with 385 additions and 14 deletions

View File

@@ -10,6 +10,7 @@ from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from app.core.auth_mode import AuthMode from app.core.auth_mode import AuthMode
from app.core.rate_limit_backend import RateLimitBackend
BACKEND_ROOT = Path(__file__).resolve().parents[2] BACKEND_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_ENV_FILE = BACKEND_ROOT / ".env" DEFAULT_ENV_FILE = BACKEND_ROOT / ".env"
@@ -60,6 +61,10 @@ class Settings(BaseSettings):
# Webhook payload size limit in bytes (default 1 MB). # Webhook payload size limit in bytes (default 1 MB).
webhook_max_payload_bytes: int = 1_048_576 webhook_max_payload_bytes: int = 1_048_576
# Rate limiting
rate_limit_backend: RateLimitBackend = RateLimitBackend.MEMORY
rate_limit_redis_url: str = ""
# Database lifecycle # Database lifecycle
db_auto_migrate: bool = False db_auto_migrate: bool = False
@@ -98,6 +103,7 @@ class Settings(BaseSettings):
raise ValueError( raise ValueError(
"LOCAL_AUTH_TOKEN must be at least 50 characters and non-placeholder when AUTH_MODE=local.", "LOCAL_AUTH_TOKEN must be at least 50 characters and non-placeholder when AUTH_MODE=local.",
) )
base_url = self.base_url.strip() base_url = self.base_url.strip()
if not base_url: if not base_url:
raise ValueError("BASE_URL must be set and non-empty.") raise ValueError("BASE_URL must be set and non-empty.")
@@ -107,6 +113,15 @@ class Settings(BaseSettings):
"BASE_URL must be an absolute http(s) URL (e.g. http://localhost:8000).", "BASE_URL must be an absolute http(s) URL (e.g. http://localhost:8000).",
) )
self.base_url = base_url.rstrip("/") self.base_url = base_url.rstrip("/")
# Rate-limit: fall back to rq_redis_url if using redis backend
# with no explicit rate-limit URL.
if (
self.rate_limit_backend == RateLimitBackend.REDIS
and not self.rate_limit_redis_url.strip()
):
self.rate_limit_redis_url = self.rq_redis_url
# In dev, default to applying Alembic migrations at startup to avoid # In dev, default to applying Alembic migrations at startup to avoid
# schema drift (e.g. missing newly-added columns). # schema drift (e.g. missing newly-added columns).
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev": if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":

View File

@@ -1,25 +1,38 @@
"""Simple in-memory sliding-window rate limiter for abuse prevention. """Sliding-window rate limiters for abuse prevention.
This provides per-IP rate limiting without external dependencies. Supports an in-memory backend (default, no external dependencies) and
Each key maintains a sliding window of recent request timestamps; a Redis-backed backend for multi-process / distributed deployments.
a request is allowed only when the number of timestamps within the Configure via RATE_LIMIT_BACKEND=memory|redis.
window is below the configured maximum.
For multi-process or distributed deployments, a Redis-based limiter
should be used instead.
""" """
from __future__ import annotations from __future__ import annotations
import time import time
import uuid
from abc import ABC, abstractmethod
from collections import deque from collections import deque
from threading import Lock from threading import Lock
import redis as redis_lib
from app.core.logging import get_logger
from app.core.rate_limit_backend import RateLimitBackend
logger = get_logger(__name__)
# Run a full sweep of all keys every 128 calls to is_allowed. # Run a full sweep of all keys every 128 calls to is_allowed.
_CLEANUP_INTERVAL = 128 _CLEANUP_INTERVAL = 128
class InMemoryRateLimiter: class RateLimiter(ABC):
"""Base interface for sliding-window rate limiters."""
@abstractmethod
def is_allowed(self, key: str) -> bool:
"""Return True if the request should be allowed, False if rate-limited."""
class InMemoryRateLimiter(RateLimiter):
"""Sliding-window rate limiter keyed by arbitrary string (typically client IP).""" """Sliding-window rate limiter keyed by arbitrary string (typically client IP)."""
def __init__(self, *, max_requests: int, window_seconds: float) -> None: def __init__(self, *, max_requests: int, window_seconds: float) -> None:
@@ -61,8 +74,103 @@ class InMemoryRateLimiter:
return True return True
class RedisRateLimiter(RateLimiter):
    """Redis-backed sliding-window rate limiter using sorted sets.

    Every tracked key maps to a Redis sorted set whose members are
    unique per-request identifiers and whose scores are wall-clock
    timestamps. A single pipelined round-trip prunes entries older
    than the window, records the current request, counts what remains,
    and refreshes the key's TTL.

    Fail-open: when Redis cannot be reached during a request, the
    request is allowed and a warning is logged.
    """

    def __init__(
        self,
        *,
        namespace: str,
        max_requests: int,
        window_seconds: float,
        redis_url: str,
    ) -> None:
        self._namespace = namespace
        self._max_requests = max_requests
        self._window_seconds = window_seconds
        self._client: redis_lib.Redis = redis_lib.Redis.from_url(redis_url)

    def is_allowed(self, key: str) -> bool:
        """Return True if the request should be allowed, False if rate-limited."""
        redis_key = f"ratelimit:{self._namespace}:{key}"
        now = time.time()
        # Unique member so two requests landing on the same clock tick
        # are both counted.
        member = f"{now}:{uuid.uuid4().hex[:8]}"
        try:
            pipeline = self._client.pipeline(transaction=True)
            pipeline.zremrangebyscore(redis_key, "-inf", now - self._window_seconds)
            pipeline.zadd(redis_key, {member: now})
            pipeline.zcard(redis_key)
            pipeline.expire(redis_key, int(self._window_seconds) + 1)
            window_count: int = pipeline.execute()[2]
        except Exception:
            logger.warning(
                "rate_limit.redis.unavailable namespace=%s key=%s",
                self._namespace,
                key,
                exc_info=True,
            )
            return True  # fail-open
        return window_count <= self._max_requests
def validate_rate_limit_redis(redis_url: str) -> None:
    """Verify the Redis rate-limit backend is reachable.

    Intended to run once at application startup so a misconfigured
    backend fails fast instead of silently degrading to fail-open
    behavior on every request.

    Raises:
        ConnectionError: if a PING to the server fails for any reason.
    """
    from urllib.parse import urlsplit, urlunsplit

    client = redis_lib.Redis.from_url(redis_url)
    try:
        client.ping()
    except Exception as exc:
        # Redact any password embedded in the URL (redis://user:pw@host)
        # so credentials never leak into logs or error reports.
        parts = urlsplit(redis_url)
        netloc = parts.netloc
        if parts.password is not None:
            netloc = f"{parts.username or ''}:***@{parts.hostname or ''}"
            if parts.port is not None:
                netloc = f"{netloc}:{parts.port}"
        safe_url = urlunsplit(
            (parts.scheme, netloc, parts.path, parts.query, parts.fragment),
        )
        raise ConnectionError(
            f"Redis rate-limit backend configured but unreachable at {safe_url}: {exc}",
        ) from exc
    finally:
        client.close()
def create_rate_limiter(
    *,
    namespace: str,
    max_requests: int,
    window_seconds: float,
) -> RateLimiter:
    """Create a rate limiter for the backend selected in settings."""
    # Local import — presumably avoids an import cycle with config; the
    # settings object is only needed at call time anyway.
    from app.core.config import settings

    if settings.rate_limit_backend != RateLimitBackend.REDIS:
        return InMemoryRateLimiter(
            max_requests=max_requests,
            window_seconds=window_seconds,
        )
    return RedisRateLimiter(
        namespace=namespace,
        max_requests=max_requests,
        window_seconds=window_seconds,
        redis_url=settings.rate_limit_redis_url,
    )
# Shared limiter instances for specific endpoints. # Shared limiter instances for specific endpoints.
# Agent auth: 20 attempts per 60 seconds per IP. # Agent auth: 20 attempts per 60 seconds per IP.
agent_auth_limiter = InMemoryRateLimiter(max_requests=20, window_seconds=60.0) agent_auth_limiter: RateLimiter = create_rate_limiter(
namespace="agent_auth",
max_requests=20,
window_seconds=60.0,
)
# Webhook ingest: 60 requests per 60 seconds per IP. # Webhook ingest: 60 requests per 60 seconds per IP.
webhook_ingest_limiter = InMemoryRateLimiter(max_requests=60, window_seconds=60.0) webhook_ingest_limiter: RateLimiter = create_rate_limiter(
namespace="webhook_ingest",
max_requests=60,
window_seconds=60.0,
)

View File

@@ -0,0 +1,12 @@
"""Rate-limit backend selection enum."""
from __future__ import annotations
from enum import Enum
class RateLimitBackend(str, Enum):
    """Available rate-limiting backend implementations.

    Values mirror the strings accepted by the RATE_LIMIT_BACKEND setting.
    """

    # Single-process in-memory sliding window (the default).
    MEMORY = "memory"
    # Redis sorted-set sliding window for multi-process deployments.
    REDIS = "redis"

View File

@@ -34,6 +34,8 @@ from app.api.users import router as users_router
from app.core.config import settings from app.core.config import settings
from app.core.error_handling import install_error_handling from app.core.error_handling import install_error_handling
from app.core.logging import configure_logging, get_logger from app.core.logging import configure_logging, get_logger
from app.core.rate_limit import validate_rate_limit_redis
from app.core.rate_limit_backend import RateLimitBackend
from app.core.security_headers import SecurityHeadersMiddleware from app.core.security_headers import SecurityHeadersMiddleware
from app.db.session import init_db from app.db.session import init_db
from app.schemas.health import HealthStatusResponse from app.schemas.health import HealthStatusResponse
@@ -437,6 +439,11 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
settings.db_auto_migrate, settings.db_auto_migrate,
) )
await init_db() await init_db()
if settings.rate_limit_backend == RateLimitBackend.REDIS:
validate_rate_limit_redis(settings.rate_limit_redis_url)
logger.info("app.lifecycle.rate_limit backend=redis")
else:
logger.info("app.lifecycle.rate_limit backend=memory")
logger.info("app.lifecycle.started") logger.info("app.lifecycle.started")
try: try:
yield yield

View File

@@ -1,11 +1,93 @@
"""Tests for the in-memory rate limiter.""" """Tests for rate limiters (in-memory and Redis-backed)."""
from __future__ import annotations from __future__ import annotations
import time import time
from unittest.mock import patch from unittest.mock import MagicMock, patch
from app.core.rate_limit import InMemoryRateLimiter import pytest
from app.core.rate_limit import (
InMemoryRateLimiter,
RedisRateLimiter,
create_rate_limiter,
validate_rate_limit_redis,
)
from app.core.rate_limit_backend import RateLimitBackend
# ---------------------------------------------------------------------------
# Fake Redis helpers for deterministic testing
# ---------------------------------------------------------------------------
class _FakePipeline:
"""Minimal sorted-set pipeline that executes against a _FakeRedis."""
def __init__(self, parent: _FakeRedis) -> None:
self._parent = parent
self._ops: list[tuple[str, ...]] = []
# Pipeline command stubs -- each just records intent and returns self
# so chaining works (even though our tests don't chain).
def zremrangebyscore(self, key: str, min_val: str, max_val: float) -> _FakePipeline:
self._ops.append(("zremrangebyscore", key, str(min_val), str(max_val)))
return self
def zadd(self, key: str, mapping: dict[str, float]) -> _FakePipeline:
self._ops.append(("zadd", key, *next(iter(mapping.items()))))
return self
def zcard(self, key: str) -> _FakePipeline:
self._ops.append(("zcard", key))
return self
def expire(self, key: str, seconds: int) -> _FakePipeline:
self._ops.append(("expire", key, str(seconds)))
return self
def execute(self) -> list[object]:
results: list[object] = []
for op in self._ops:
cmd = op[0]
key = op[1]
zset = self._parent._sorted_sets.setdefault(key, {})
if cmd == "zremrangebyscore":
max_score = float(op[3])
expired = [m for m, s in zset.items() if s <= max_score]
for m in expired:
del zset[m]
results.append(len(expired))
elif cmd == "zadd":
member, score = op[2], float(op[3])
zset[member] = score
results.append(1)
elif cmd == "zcard":
results.append(len(zset))
elif cmd == "expire":
results.append(True)
return results
class _FakeRedis:
"""Minimal in-process Redis fake supporting sorted-set pipeline ops."""
def __init__(self) -> None:
self._sorted_sets: dict[str, dict[str, float]] = {}
def pipeline(self, *, transaction: bool = True) -> _FakePipeline:
return _FakePipeline(self)
def ping(self) -> bool:
return True
def close(self) -> None:
pass
# ---------------------------------------------------------------------------
# InMemoryRateLimiter tests (unchanged from original)
# ---------------------------------------------------------------------------
def test_allows_requests_within_limit() -> None: def test_allows_requests_within_limit() -> None:
@@ -69,3 +151,150 @@ def test_sweep_removes_expired_keys() -> None:
# Stale keys should have been swept; only "trigger-sweep" should remain # Stale keys should have been swept; only "trigger-sweep" should remain
assert "stale-0" not in limiter._buckets assert "stale-0" not in limiter._buckets
assert "trigger-sweep" in limiter._buckets assert "trigger-sweep" in limiter._buckets
# ---------------------------------------------------------------------------
# RedisRateLimiter tests
# ---------------------------------------------------------------------------
def _make_redis_limiter(
    fake: _FakeRedis,
    *,
    namespace: str = "test",
    max_requests: int = 5,
    window_seconds: float = 60.0,
) -> RedisRateLimiter:
    """Build a RedisRateLimiter whose client is the given _FakeRedis."""
    # Intercept the real client factory so the limiter talks to our fake.
    with patch("redis.Redis.from_url", return_value=fake):
        limiter = RedisRateLimiter(
            namespace=namespace,
            max_requests=max_requests,
            window_seconds=window_seconds,
            redis_url="redis://fake:6379/0",
        )
    return limiter
def test_redis_allows_within_limit() -> None:
    """All requests within the configured budget are admitted."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=5)
    for _ in range(5):
        assert limiter.is_allowed("client-a") is True
def test_redis_blocks_over_limit() -> None:
    """Once the budget is exhausted, further requests are rejected."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=3)
    for _ in range(3):
        assert limiter.is_allowed("client-a") is True
    # Subsequent calls inside the window stay blocked.
    assert limiter.is_allowed("client-a") is False
    assert limiter.is_allowed("client-a") is False
def test_redis_separate_keys_independent() -> None:
    """Exhausting one key's budget does not affect another key."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=2)
    for client in ("client-a", "client-b"):
        assert limiter.is_allowed(client) is True
        assert limiter.is_allowed(client) is True
        assert limiter.is_allowed(client) is False
def test_redis_window_expiry() -> None:
    """Entries older than the window no longer count against the limit."""
    limiter = _make_redis_limiter(_FakeRedis(), max_requests=2, window_seconds=1.0)
    assert limiter.is_allowed("client-a") is True
    assert limiter.is_allowed("client-a") is True
    assert limiter.is_allowed("client-a") is False
    # Jump the clock past the window; old entries should be pruned.
    with patch("time.time", return_value=time.time() + 2.0):
        assert limiter.is_allowed("client-a") is True
def _make_broken_pipeline(*, transaction: bool = True) -> MagicMock:
    """Return a pipeline mock whose execute() raises, simulating Redis loss.

    Shared by both fail-open tests below (previously duplicated inline
    in each test).
    """
    pipe = MagicMock()
    pipe.zremrangebyscore.return_value = pipe
    pipe.zadd.return_value = pipe
    pipe.zcard.return_value = pipe
    pipe.expire.return_value = pipe
    pipe.execute.side_effect = ConnectionError("Redis gone")
    return pipe


def test_redis_fail_open_on_error() -> None:
    """When Redis is unreachable, requests should be allowed (fail-open)."""
    fake = _FakeRedis()
    limiter = _make_redis_limiter(fake, max_requests=1)
    limiter._client.pipeline = _make_broken_pipeline  # type: ignore[assignment]
    # Should still allow (fail-open) even though Redis is broken
    assert limiter.is_allowed("client-a") is True
    assert limiter.is_allowed("client-a") is True  # would normally be blocked


def test_redis_fail_open_logs_warning() -> None:
    """Verify a warning is logged when Redis is unreachable."""
    fake = _FakeRedis()
    limiter = _make_redis_limiter(fake, max_requests=1)
    limiter._client.pipeline = _make_broken_pipeline  # type: ignore[assignment]
    with patch("app.core.rate_limit.logger") as mock_logger:
        limiter.is_allowed("client-a")
        mock_logger.warning.assert_called_once()
# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------
def test_factory_returns_memory_by_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """The memory backend setting yields an InMemoryRateLimiter."""
    monkeypatch.setattr(
        "app.core.config.settings.rate_limit_backend", RateLimitBackend.MEMORY
    )
    limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
    assert isinstance(limiter, InMemoryRateLimiter)
def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch) -> None:
    """The redis backend setting yields a RedisRateLimiter built from settings."""
    monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS)
    monkeypatch.setattr(
        "app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0"
    )
    with patch("redis.Redis.from_url", return_value=_FakeRedis()):
        limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0)
    assert isinstance(limiter, RedisRateLimiter)
# ---------------------------------------------------------------------------
# Startup validation tests
# ---------------------------------------------------------------------------
def test_validate_redis_succeeds_when_reachable() -> None:
    """A reachable server passes startup validation without raising."""
    with patch("redis.Redis.from_url", return_value=_FakeRedis()):
        validate_rate_limit_redis("redis://localhost:6379/0")
def test_validate_redis_raises_on_unreachable() -> None:
    """An unreachable server raises ConnectionError mentioning 'unreachable'."""
    broken_client = MagicMock()
    broken_client.ping.side_effect = ConnectionError("refused")
    with patch("redis.Redis.from_url", return_value=broken_client):
        with pytest.raises(ConnectionError, match="unreachable"):
            validate_rate_limit_redis("redis://bad:6379/0")