diff --git a/Makefile b/Makefile index eac81243..b1dc5aa1 100644 --- a/Makefile +++ b/Makefile @@ -122,19 +122,16 @@ backend-migration-check: ## Validate migration graph + reversible path on clean LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \ BASE_URL=http://localhost:8000 \ DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \ - BASE_URL=http://localhost:8000 \ uv run alembic upgrade head && \ AUTH_MODE=local \ LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \ BASE_URL=http://localhost:8000 \ DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \ - BASE_URL=http://localhost:8000 \ uv run alembic downgrade base && \ AUTH_MODE=local \ LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \ BASE_URL=http://localhost:8000 \ DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \ - BASE_URL=http://localhost:8000 \ uv run alembic upgrade head .PHONY: build diff --git a/backend/app/core/rate_limit.py b/backend/app/core/rate_limit.py index 3903d9ca..cd4a1e80 100644 --- a/backend/app/core/rate_limit.py +++ b/backend/app/core/rate_limit.py @@ -12,6 +12,7 @@ import uuid from abc import ABC, abstractmethod from collections import deque from threading import Lock +from typing import Awaitable, cast import redis as redis_lib import redis.asyncio as aioredis @@ -24,6 +25,27 @@ logger = get_logger(__name__) # Run a full sweep of all keys every 128 calls to is_allowed. _CLEANUP_INTERVAL = 128 +# Redis sliding-window script that bounds per-key storage to +# ``max_requests`` while preserving the current "blocked attempts extend +# the window" behavior by retaining the most recent attempts. +_REDIS_IS_ALLOWED_SCRIPT = """ +redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", ARGV[1]) +local count = redis.call("ZCARD", KEYS[1]) +if count < tonumber(ARGV[4]) then + redis.call("ZADD", KEYS[1], ARGV[2], ARGV[3]) + redis.call("EXPIRE", KEYS[1], ARGV[5]) + return 1 +end + +local oldest = redis.call("ZRANGE", KEYS[1], 0, 0) +if oldest[1] then + redis.call("ZREM", KEYS[1], oldest[1]) +end +redis.call("ZADD", KEYS[1], ARGV[2], ARGV[3]) +redis.call("EXPIRE", KEYS[1], ARGV[5]) +return 0 +""" + # Shared async Redis clients keyed by URL to avoid duplicate connection pools. _async_redis_clients: dict[str, aioredis.Redis] = {} @@ -81,17 +103,25 @@ class InMemoryRateLimiter(RateLimiter): # Prune expired entries from the front (timestamps are monotonic) while timestamps and timestamps[0] <= cutoff: timestamps.popleft() + if len(timestamps) < self._max_requests: + timestamps.append(now) + return True + + # Retain only the latest ``max_requests`` attempts so + # sustained abuse keeps extending the window without letting + # the bucket grow unbounded. + timestamps.popleft() timestamps.append(now) - return len(timestamps) <= self._max_requests + return False class RedisRateLimiter(RateLimiter): """Redis-backed sliding-window rate limiter using sorted sets. Each key is stored as a Redis sorted set where members are unique - request identifiers and scores are wall-clock timestamps. An async - pipeline prunes expired entries, adds the new request, counts the - window, and sets a TTL — all in a single round-trip. + request identifiers and scores are wall-clock timestamps. A Lua + script prunes expired entries, updates the set, and keeps storage + bounded to the most recent ``max_requests`` attempts. Fail-open: if Redis is unreachable during a request, the request is allowed and a warning is logged. @@ -118,13 +148,19 @@ class RedisRateLimiter(RateLimiter): member = f"{now}:{uuid.uuid4().hex[:8]}" try: - pipe = self._client.pipeline(transaction=True) - pipe.zremrangebyscore(redis_key, "-inf", cutoff) - pipe.zadd(redis_key, {member: now}) - pipe.zcard(redis_key) - pipe.expire(redis_key, int(self._window_seconds) + 1) - results = await pipe.execute() - count: int = results[2] + allowed = await cast( + Awaitable[object], + self._client.eval( + _REDIS_IS_ALLOWED_SCRIPT, + 1, + redis_key, + str(cutoff), + str(now), + member, + str(self._max_requests), + str(int(self._window_seconds) + 1), + ), + ) except Exception: logger.warning( "rate_limit.redis.unavailable namespace=%s key=%s", @@ -134,7 +170,7 @@ class RedisRateLimiter(RateLimiter): ) return True # fail-open - return count <= self._max_requests + return bool(allowed) def _redact_url(url: str) -> str: diff --git a/backend/tests/test_rate_limit.py b/backend/tests/test_rate_limit.py index 5741ee12..aac40ae8 100644 --- a/backend/tests/test_rate_limit.py +++ b/backend/tests/test_rate_limit.py @@ -15,68 +15,38 @@ from app.core.rate_limit import ( ) from app.core.rate_limit_backend import RateLimitBackend -# --------------------------------------------------------------------------- -# Fake Redis helpers for deterministic testing -# --------------------------------------------------------------------------- - - -class _FakePipeline: - """Minimal sorted-set pipeline that executes against a _FakeRedis.""" - - def __init__(self, parent: _FakeRedis) -> None: - self._parent = parent - self._ops: list[tuple[str, ...]] = [] - - # Pipeline command stubs -- each just records intent and returns self - # so chaining works (even though our tests don't chain). - - def zremrangebyscore(self, key: str, min_val: str, max_val: float) -> _FakePipeline: - self._ops.append(("zremrangebyscore", key, str(min_val), str(max_val))) - return self - - def zadd(self, key: str, mapping: dict[str, float]) -> _FakePipeline: - self._ops.append(("zadd", key, *next(iter(mapping.items())))) - return self - - def zcard(self, key: str) -> _FakePipeline: - self._ops.append(("zcard", key)) - return self - - def expire(self, key: str, seconds: int) -> _FakePipeline: - self._ops.append(("expire", key, str(seconds))) - return self - - async def execute(self) -> list[object]: - results: list[object] = [] - for op in self._ops: - cmd = op[0] - key = op[1] - zset = self._parent._sorted_sets.setdefault(key, {}) - if cmd == "zremrangebyscore": - max_score = float(op[3]) - expired = [m for m, s in zset.items() if s <= max_score] - for m in expired: - del zset[m] - results.append(len(expired)) - elif cmd == "zadd": - member, score = op[2], float(op[3]) - zset[member] = score - results.append(1) - elif cmd == "zcard": - results.append(len(zset)) - elif cmd == "expire": - results.append(True) - return results - - class _FakeRedis: - """Minimal in-process Redis fake supporting sorted-set pipeline ops.""" + """Minimal in-process Redis fake supporting the limiter Lua script.""" def __init__(self) -> None: self._sorted_sets: dict[str, dict[str, float]] = {} - def pipeline(self, *, transaction: bool = True) -> _FakePipeline: - return _FakePipeline(self) + async def eval( + self, + script: str, + numkeys: int, + key: str, + cutoff: float, + now: float, + member: str, + max_requests: int, + ttl: int, + ) -> int: + del script, numkeys, ttl + + zset = self._sorted_sets.setdefault(key, {}) + expired = [m for m, s in zset.items() if s <= float(cutoff)] + for m in expired: + del zset[m] + + if len(zset) < int(max_requests): + zset[member] = float(now) + return 1 + + oldest_member = min(zset, key=zset.__getitem__) + del zset[oldest_member] + zset[member] = float(now) + return 0 def ping(self) -> bool: return True @@ -106,6 +76,19 @@ async def test_blocks_requests_over_limit() -> None: assert await limiter.is_allowed("client-a") is False +@pytest.mark.asyncio() +async def test_blocked_requests_extend_window_without_growing_memory() -> None: + limiter = InMemoryRateLimiter(max_requests=2, window_seconds=1.0) + with patch("time.monotonic", side_effect=[0.0, 0.1, 0.2, 1.05, 1.21]): + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + + assert len(limiter._buckets["client-a"]) == 2 + + @pytest.mark.asyncio() async def test_separate_keys_have_independent_limits() -> None: limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0) @@ -198,6 +181,22 @@ async def test_redis_blocks_over_limit() -> None: assert await limiter.is_allowed("client-a") is False +@pytest.mark.asyncio() +async def test_redis_blocked_requests_extend_window_without_growing_storage() -> None: + fake = _FakeRedis() + limiter = _make_redis_limiter(fake, max_requests=2, window_seconds=1.0) + redis_key = "ratelimit:test:client-a" + + with patch("time.time", side_effect=[0.0, 0.1, 0.2, 1.05, 1.21]): + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is True + assert await limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is False + assert await limiter.is_allowed("client-a") is True + + assert len(fake._sorted_sets[redis_key]) == 2 + + @pytest.mark.asyncio() async def test_redis_separate_keys_independent() -> None: fake = _FakeRedis() @@ -231,17 +230,8 @@ async def test_redis_fail_open_on_error() -> None: fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=1) - # Make the pipeline raise on execute - def _broken_pipeline(*, transaction: bool = True) -> MagicMock: - pipe = MagicMock() - pipe.zremrangebyscore.return_value = pipe - pipe.zadd.return_value = pipe - pipe.zcard.return_value = pipe - pipe.expire.return_value = pipe - pipe.execute.side_effect = ConnectionError("Redis gone") - return pipe - - limiter._client.pipeline = _broken_pipeline # type: ignore[assignment] + broken_eval = MagicMock(side_effect=ConnectionError("Redis gone")) + limiter._client.eval = broken_eval # type: ignore[assignment] # Should still allow (fail-open) even though Redis is broken assert await limiter.is_allowed("client-a") is True @@ -254,16 +244,7 @@ async def test_redis_fail_open_logs_warning() -> None: fake = _FakeRedis() limiter = _make_redis_limiter(fake, max_requests=1) - def _broken_pipeline(*, transaction: bool = True) -> MagicMock: - pipe = MagicMock() - pipe.zremrangebyscore.return_value = pipe - pipe.zadd.return_value = pipe - pipe.zcard.return_value = pipe - pipe.expire.return_value = pipe - pipe.execute.side_effect = ConnectionError("Redis gone") - return pipe - - limiter._client.pipeline = _broken_pipeline # type: ignore[assignment] + limiter._client.eval = MagicMock(side_effect=ConnectionError("Redis gone")) # type: ignore[assignment] with patch("app.core.rate_limit.logger") as mock_logger: await limiter.is_allowed("client-a")