diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 3e9f6c6e..6809b44c 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -10,6 +10,7 @@ from pydantic import Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from app.core.auth_mode import AuthMode +from app.core.rate_limit_backend import RateLimitBackend BACKEND_ROOT = Path(__file__).resolve().parents[2] DEFAULT_ENV_FILE = BACKEND_ROOT / ".env" @@ -60,6 +61,10 @@ class Settings(BaseSettings): # Webhook payload size limit in bytes (default 1 MB). webhook_max_payload_bytes: int = 1_048_576 + # Rate limiting + rate_limit_backend: RateLimitBackend = RateLimitBackend.MEMORY + rate_limit_redis_url: str = "" + # Database lifecycle db_auto_migrate: bool = False @@ -98,6 +103,7 @@ class Settings(BaseSettings): raise ValueError( "LOCAL_AUTH_TOKEN must be at least 50 characters and non-placeholder when AUTH_MODE=local.", ) + base_url = self.base_url.strip() if not base_url: raise ValueError("BASE_URL must be set and non-empty.") @@ -107,6 +113,15 @@ class Settings(BaseSettings): "BASE_URL must be an absolute http(s) URL (e.g. http://localhost:8000).", ) self.base_url = base_url.rstrip("/") + + # Rate-limit: fall back to rq_redis_url if using redis backend + # with no explicit rate-limit URL. + if ( + self.rate_limit_backend == RateLimitBackend.REDIS + and not self.rate_limit_redis_url.strip() + ): + self.rate_limit_redis_url = self.rq_redis_url + # In dev, default to applying Alembic migrations at startup to avoid # schema drift (e.g. missing newly-added columns). if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev": diff --git a/backend/app/core/rate_limit.py b/backend/app/core/rate_limit.py index b1f2abaf..97a2079a 100644 --- a/backend/app/core/rate_limit.py +++ b/backend/app/core/rate_limit.py @@ -1,25 +1,38 @@ -"""Simple in-memory sliding-window rate limiter for abuse prevention. +"""Sliding-window rate limiters for abuse prevention. -This provides per-IP rate limiting without external dependencies. -Each key maintains a sliding window of recent request timestamps; -a request is allowed only when the number of timestamps within the -window is below the configured maximum. - -For multi-process or distributed deployments, a Redis-based limiter -should be used instead. +Supports an in-memory backend (default, no external dependencies) and +a Redis-backed backend for multi-process / distributed deployments. +Configure via RATE_LIMIT_BACKEND=memory|redis. """ from __future__ import annotations import time +import uuid +from abc import ABC, abstractmethod from collections import deque from threading import Lock +import redis as redis_lib + +from app.core.logging import get_logger +from app.core.rate_limit_backend import RateLimitBackend + +logger = get_logger(__name__) + # Run a full sweep of all keys every 128 calls to is_allowed. _CLEANUP_INTERVAL = 128 -class InMemoryRateLimiter: +class RateLimiter(ABC): + """Base interface for sliding-window rate limiters.""" + + @abstractmethod + def is_allowed(self, key: str) -> bool: + """Return True if the request should be allowed, False if rate-limited.""" + + +class InMemoryRateLimiter(RateLimiter): """Sliding-window rate limiter keyed by arbitrary string (typically client IP).""" def __init__(self, *, max_requests: int, window_seconds: float) -> None: @@ -61,8 +74,103 @@ class InMemoryRateLimiter: return True +class RedisRateLimiter(RateLimiter): + """Redis-backed sliding-window rate limiter using sorted sets. + + Each key is stored as a Redis sorted set where members are unique + request identifiers and scores are wall-clock timestamps. A pipeline + prunes expired entries, adds the new request, counts the window, and + sets a TTL — all in a single round-trip. + + Fail-open: if Redis is unreachable during a request, the request is + allowed and a warning is logged. + """ + + def __init__( + self, + *, + namespace: str, + max_requests: int, + window_seconds: float, + redis_url: str, + ) -> None: + self._namespace = namespace + self._max_requests = max_requests + self._window_seconds = window_seconds + self._client: redis_lib.Redis = redis_lib.Redis.from_url(redis_url) + + def is_allowed(self, key: str) -> bool: + """Return True if the request should be allowed, False if rate-limited.""" + redis_key = f"ratelimit:{self._namespace}:{key}" + now = time.time() + cutoff = now - self._window_seconds + member = f"{now}:{uuid.uuid4().hex[:8]}" + + try: + pipe = self._client.pipeline(transaction=True) + pipe.zremrangebyscore(redis_key, "-inf", cutoff) + pipe.zadd(redis_key, {member: now}) + pipe.zcard(redis_key) + pipe.expire(redis_key, int(self._window_seconds) + 1) + results = pipe.execute() + count: int = results[2] + except Exception: + logger.warning( + "rate_limit.redis.unavailable namespace=%s key=%s", + self._namespace, + key, + exc_info=True, + ) + return True # fail-open + + return count <= self._max_requests + + +def validate_rate_limit_redis(redis_url: str) -> None: + """Verify Redis is reachable. Raises ``ConnectionError`` on failure.""" + client = redis_lib.Redis.from_url(redis_url) + try: + client.ping() + except Exception as exc: + raise ConnectionError( + f"Redis rate-limit backend configured but unreachable at {redis_url}: {exc}", + ) from exc + finally: + client.close() + + +def create_rate_limiter( + *, + namespace: str, + max_requests: int, + window_seconds: float, +) -> RateLimiter: + """Create a rate limiter based on the configured backend.""" + from app.core.config import settings + + if settings.rate_limit_backend == RateLimitBackend.REDIS: + return RedisRateLimiter( + namespace=namespace, + max_requests=max_requests, + window_seconds=window_seconds, + redis_url=settings.rate_limit_redis_url, + ) + return InMemoryRateLimiter( + max_requests=max_requests, + window_seconds=window_seconds, + ) + + # Shared limiter instances for specific endpoints. # Agent auth: 20 attempts per 60 seconds per IP. -agent_auth_limiter = InMemoryRateLimiter(max_requests=20, window_seconds=60.0) +agent_auth_limiter: RateLimiter = create_rate_limiter( + namespace="agent_auth", + max_requests=20, + window_seconds=60.0, +) # Webhook ingest: 60 requests per 60 seconds per IP. -webhook_ingest_limiter = InMemoryRateLimiter(max_requests=60, window_seconds=60.0) +webhook_ingest_limiter: RateLimiter = create_rate_limiter( + namespace="webhook_ingest", + max_requests=60, + window_seconds=60.0, +) diff --git a/backend/app/core/rate_limit_backend.py b/backend/app/core/rate_limit_backend.py new file mode 100644 index 00000000..b85588c1 --- /dev/null +++ b/backend/app/core/rate_limit_backend.py @@ -0,0 +1,12 @@ +"""Rate-limit backend selection enum.""" + +from __future__ import annotations + +from enum import Enum + + +class RateLimitBackend(str, Enum): + """Supported rate-limiting backends.""" + + MEMORY = "memory" + REDIS = "redis" diff --git a/backend/app/main.py b/backend/app/main.py index 4761c24f..857c29c8 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -34,6 +34,8 @@ from app.api.users import router as users_router from app.core.config import settings from app.core.error_handling import install_error_handling from app.core.logging import configure_logging, get_logger +from app.core.rate_limit import validate_rate_limit_redis +from app.core.rate_limit_backend import RateLimitBackend from app.core.security_headers import SecurityHeadersMiddleware from app.db.session import init_db from app.schemas.health import HealthStatusResponse @@ -437,6 +439,11 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]: settings.db_auto_migrate, ) await init_db() + if settings.rate_limit_backend == RateLimitBackend.REDIS: + validate_rate_limit_redis(settings.rate_limit_redis_url) + logger.info("app.lifecycle.rate_limit backend=redis") + else: + logger.info("app.lifecycle.rate_limit backend=memory") logger.info("app.lifecycle.started") try: yield diff --git a/backend/tests/test_rate_limit.py b/backend/tests/test_rate_limit.py index 66fe00a2..cc5711f3 100644 --- a/backend/tests/test_rate_limit.py +++ b/backend/tests/test_rate_limit.py @@ -1,11 +1,93 @@ -"""Tests for the in-memory rate limiter.""" +"""Tests for rate limiters (in-memory and Redis-backed).""" from __future__ import annotations import time -from unittest.mock import patch +from unittest.mock import MagicMock, patch -from app.core.rate_limit import InMemoryRateLimiter +import pytest + +from app.core.rate_limit import ( + InMemoryRateLimiter, + RedisRateLimiter, + create_rate_limiter, + validate_rate_limit_redis, +) +from app.core.rate_limit_backend import RateLimitBackend + +# --------------------------------------------------------------------------- +# Fake Redis helpers for deterministic testing +# --------------------------------------------------------------------------- + + +class _FakePipeline: + """Minimal sorted-set pipeline that executes against a _FakeRedis.""" + + def __init__(self, parent: _FakeRedis) -> None: + self._parent = parent + self._ops: list[tuple[str, ...]] = [] + + # Pipeline command stubs -- each just records intent and returns self + # so chaining works (even though our tests don't chain). + + def zremrangebyscore(self, key: str, min_val: str, max_val: float) -> _FakePipeline: + self._ops.append(("zremrangebyscore", key, str(min_val), str(max_val))) + return self + + def zadd(self, key: str, mapping: dict[str, float]) -> _FakePipeline: + self._ops.append(("zadd", key, *next(iter(mapping.items())))) + return self + + def zcard(self, key: str) -> _FakePipeline: + self._ops.append(("zcard", key)) + return self + + def expire(self, key: str, seconds: int) -> _FakePipeline: + self._ops.append(("expire", key, str(seconds))) + return self + + def execute(self) -> list[object]: + results: list[object] = [] + for op in self._ops: + cmd = op[0] + key = op[1] + zset = self._parent._sorted_sets.setdefault(key, {}) + if cmd == "zremrangebyscore": + max_score = float(op[3]) + expired = [m for m, s in zset.items() if s <= max_score] + for m in expired: + del zset[m] + results.append(len(expired)) + elif cmd == "zadd": + member, score = op[2], float(op[3]) + zset[member] = score + results.append(1) + elif cmd == "zcard": + results.append(len(zset)) + elif cmd == "expire": + results.append(True) + return results + + +class _FakeRedis: + """Minimal in-process Redis fake supporting sorted-set pipeline ops.""" + + def __init__(self) -> None: + self._sorted_sets: dict[str, dict[str, float]] = {} + + def pipeline(self, *, transaction: bool = True) -> _FakePipeline: + return _FakePipeline(self) + + def ping(self) -> bool: + return True + + def close(self) -> None: + pass + + +# --------------------------------------------------------------------------- +# InMemoryRateLimiter tests (unchanged from original) +# --------------------------------------------------------------------------- def test_allows_requests_within_limit() -> None: @@ -69,3 +151,150 @@ def test_sweep_removes_expired_keys() -> None: # Stale keys should have been swept; only "trigger-sweep" should remain assert "stale-0" not in limiter._buckets assert "trigger-sweep" in limiter._buckets + + +# --------------------------------------------------------------------------- +# RedisRateLimiter tests +# --------------------------------------------------------------------------- + + +def _make_redis_limiter( + fake: _FakeRedis, + *, + namespace: str = "test", + max_requests: int = 5, + window_seconds: float = 60.0, +) -> RedisRateLimiter: + """Build a RedisRateLimiter wired to a _FakeRedis instance.""" + with patch("redis.Redis.from_url", return_value=fake): + return RedisRateLimiter( + namespace=namespace, + max_requests=max_requests, + window_seconds=window_seconds, + redis_url="redis://fake:6379/0", + ) + + +def test_redis_allows_within_limit() -> None: + fake = _FakeRedis() + limiter = _make_redis_limiter(fake, max_requests=5) + for _ in range(5): + assert limiter.is_allowed("client-a") is True + + +def test_redis_blocks_over_limit() -> None: + fake = _FakeRedis() + limiter = _make_redis_limiter(fake, max_requests=3) + for _ in range(3): + assert limiter.is_allowed("client-a") is True + assert limiter.is_allowed("client-a") is False + assert limiter.is_allowed("client-a") is False + + +def test_redis_separate_keys_independent() -> None: + fake = _FakeRedis() + limiter = _make_redis_limiter(fake, max_requests=2) + assert limiter.is_allowed("client-a") is True + assert limiter.is_allowed("client-a") is True + assert limiter.is_allowed("client-a") is False + # Different key still allowed + assert limiter.is_allowed("client-b") is True + assert limiter.is_allowed("client-b") is True + assert limiter.is_allowed("client-b") is False + + +def test_redis_window_expiry() -> None: + fake = _FakeRedis() + limiter = _make_redis_limiter(fake, max_requests=2, window_seconds=1.0) + assert limiter.is_allowed("client-a") is True + assert limiter.is_allowed("client-a") is True + assert limiter.is_allowed("client-a") is False + + # Simulate time passing beyond the window + future = time.time() + 2.0 + with patch("time.time", return_value=future): + assert limiter.is_allowed("client-a") is True + + +def test_redis_fail_open_on_error() -> None: + """When Redis is unreachable, requests should be allowed (fail-open).""" + fake = _FakeRedis() + limiter = _make_redis_limiter(fake, max_requests=1) + + # Make the pipeline raise on execute + def _broken_pipeline(*, transaction: bool = True) -> MagicMock: + pipe = MagicMock() + pipe.zremrangebyscore.return_value = pipe + pipe.zadd.return_value = pipe + pipe.zcard.return_value = pipe + pipe.expire.return_value = pipe + pipe.execute.side_effect = ConnectionError("Redis gone") + return pipe + + limiter._client.pipeline = _broken_pipeline # type: ignore[assignment] + + # Should still allow (fail-open) even though Redis is broken + assert limiter.is_allowed("client-a") is True + assert limiter.is_allowed("client-a") is True # would normally be blocked + + +def test_redis_fail_open_logs_warning() -> None: + """Verify a warning is logged when Redis is unreachable.""" + fake = _FakeRedis() + limiter = _make_redis_limiter(fake, max_requests=1) + + def _broken_pipeline(*, transaction: bool = True) -> MagicMock: + pipe = MagicMock() + pipe.zremrangebyscore.return_value = pipe + pipe.zadd.return_value = pipe + pipe.zcard.return_value = pipe + pipe.expire.return_value = pipe + pipe.execute.side_effect = ConnectionError("Redis gone") + return pipe + + limiter._client.pipeline = _broken_pipeline # type: ignore[assignment] + + with patch("app.core.rate_limit.logger") as mock_logger: + limiter.is_allowed("client-a") + mock_logger.warning.assert_called_once() + + +# --------------------------------------------------------------------------- +# Factory tests +# --------------------------------------------------------------------------- + + +def test_factory_returns_memory_by_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.MEMORY) + limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0) + assert isinstance(limiter, InMemoryRateLimiter) + + +def test_factory_returns_redis_when_configured(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("app.core.config.settings.rate_limit_backend", RateLimitBackend.REDIS) + monkeypatch.setattr( + "app.core.config.settings.rate_limit_redis_url", "redis://localhost:6379/0" + ) + fake = _FakeRedis() + with patch("redis.Redis.from_url", return_value=fake): + limiter = create_rate_limiter(namespace="test", max_requests=10, window_seconds=60.0) + assert isinstance(limiter, RedisRateLimiter) + + +# --------------------------------------------------------------------------- +# Startup validation tests +# --------------------------------------------------------------------------- + + +def test_validate_redis_succeeds_when_reachable() -> None: + fake = _FakeRedis() + with patch("redis.Redis.from_url", return_value=fake): + validate_rate_limit_redis("redis://localhost:6379/0") + + +def test_validate_redis_raises_on_unreachable() -> None: + mock_client = MagicMock() + mock_client.ping.side_effect = ConnectionError("refused") + with patch("redis.Redis.from_url", return_value=mock_client): + with pytest.raises(ConnectionError, match="unreachable"): + validate_rate_limit_redis("redis://bad:6379/0")