fix(security): Address PR review feedback
This commit is contained in:
3
Makefile
3
Makefile
@@ -122,19 +122,16 @@ backend-migration-check: ## Validate migration graph + reversible path on clean
|
|||||||
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
|
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
|
||||||
BASE_URL=http://localhost:8000 \
|
BASE_URL=http://localhost:8000 \
|
||||||
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
|
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
|
||||||
BASE_URL=http://localhost:8000 \
|
|
||||||
uv run alembic upgrade head && \
|
uv run alembic upgrade head && \
|
||||||
AUTH_MODE=local \
|
AUTH_MODE=local \
|
||||||
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
|
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
|
||||||
BASE_URL=http://localhost:8000 \
|
BASE_URL=http://localhost:8000 \
|
||||||
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
|
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
|
||||||
BASE_URL=http://localhost:8000 \
|
|
||||||
uv run alembic downgrade base && \
|
uv run alembic downgrade base && \
|
||||||
AUTH_MODE=local \
|
AUTH_MODE=local \
|
||||||
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
|
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
|
||||||
BASE_URL=http://localhost:8000 \
|
BASE_URL=http://localhost:8000 \
|
||||||
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
|
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
|
||||||
BASE_URL=http://localhost:8000 \
|
|
||||||
uv run alembic upgrade head
|
uv run alembic upgrade head
|
||||||
|
|
||||||
.PHONY: build
|
.PHONY: build
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import uuid
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
from typing import Awaitable, cast
|
||||||
|
|
||||||
import redis as redis_lib
|
import redis as redis_lib
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
@@ -24,6 +25,27 @@ logger = get_logger(__name__)
|
|||||||
# Run a full sweep of all keys every 128 calls to is_allowed.
|
# Run a full sweep of all keys every 128 calls to is_allowed.
|
||||||
_CLEANUP_INTERVAL = 128
|
_CLEANUP_INTERVAL = 128
|
||||||
|
|
||||||
|
# Redis sliding-window script that bounds per-key storage to
|
||||||
|
# ``max_requests`` while preserving the current "blocked attempts extend
|
||||||
|
# the window" behavior by retaining the most recent attempts.
|
||||||
|
_REDIS_IS_ALLOWED_SCRIPT = """
|
||||||
|
redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", ARGV[1])
|
||||||
|
local count = redis.call("ZCARD", KEYS[1])
|
||||||
|
if count < tonumber(ARGV[4]) then
|
||||||
|
redis.call("ZADD", KEYS[1], ARGV[2], ARGV[3])
|
||||||
|
redis.call("EXPIRE", KEYS[1], ARGV[5])
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
|
||||||
|
local oldest = redis.call("ZRANGE", KEYS[1], 0, 0)
|
||||||
|
if oldest[1] then
|
||||||
|
redis.call("ZREM", KEYS[1], oldest[1])
|
||||||
|
end
|
||||||
|
redis.call("ZADD", KEYS[1], ARGV[2], ARGV[3])
|
||||||
|
redis.call("EXPIRE", KEYS[1], ARGV[5])
|
||||||
|
return 0
|
||||||
|
"""
|
||||||
|
|
||||||
# Shared async Redis clients keyed by URL to avoid duplicate connection pools.
|
# Shared async Redis clients keyed by URL to avoid duplicate connection pools.
|
||||||
_async_redis_clients: dict[str, aioredis.Redis] = {}
|
_async_redis_clients: dict[str, aioredis.Redis] = {}
|
||||||
|
|
||||||
@@ -81,17 +103,25 @@ class InMemoryRateLimiter(RateLimiter):
|
|||||||
# Prune expired entries from the front (timestamps are monotonic)
|
# Prune expired entries from the front (timestamps are monotonic)
|
||||||
while timestamps and timestamps[0] <= cutoff:
|
while timestamps and timestamps[0] <= cutoff:
|
||||||
timestamps.popleft()
|
timestamps.popleft()
|
||||||
|
if len(timestamps) < self._max_requests:
|
||||||
timestamps.append(now)
|
timestamps.append(now)
|
||||||
return len(timestamps) <= self._max_requests
|
return True
|
||||||
|
|
||||||
|
# Retain only the latest ``max_requests`` attempts so
|
||||||
|
# sustained abuse keeps extending the window without letting
|
||||||
|
# the bucket grow unbounded.
|
||||||
|
timestamps.popleft()
|
||||||
|
timestamps.append(now)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class RedisRateLimiter(RateLimiter):
|
class RedisRateLimiter(RateLimiter):
|
||||||
"""Redis-backed sliding-window rate limiter using sorted sets.
|
"""Redis-backed sliding-window rate limiter using sorted sets.
|
||||||
|
|
||||||
Each key is stored as a Redis sorted set where members are unique
|
Each key is stored as a Redis sorted set where members are unique
|
||||||
request identifiers and scores are wall-clock timestamps. An async
|
request identifiers and scores are wall-clock timestamps. A Lua
|
||||||
pipeline prunes expired entries, adds the new request, counts the
|
script prunes expired entries, updates the set, and keeps storage
|
||||||
window, and sets a TTL — all in a single round-trip.
|
bounded to the most recent ``max_requests`` attempts.
|
||||||
|
|
||||||
Fail-open: if Redis is unreachable during a request, the request is
|
Fail-open: if Redis is unreachable during a request, the request is
|
||||||
allowed and a warning is logged.
|
allowed and a warning is logged.
|
||||||
@@ -118,13 +148,19 @@ class RedisRateLimiter(RateLimiter):
|
|||||||
member = f"{now}:{uuid.uuid4().hex[:8]}"
|
member = f"{now}:{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pipe = self._client.pipeline(transaction=True)
|
allowed = await cast(
|
||||||
pipe.zremrangebyscore(redis_key, "-inf", cutoff)
|
Awaitable[object],
|
||||||
pipe.zadd(redis_key, {member: now})
|
self._client.eval(
|
||||||
pipe.zcard(redis_key)
|
_REDIS_IS_ALLOWED_SCRIPT,
|
||||||
pipe.expire(redis_key, int(self._window_seconds) + 1)
|
1,
|
||||||
results = await pipe.execute()
|
redis_key,
|
||||||
count: int = results[2]
|
str(cutoff),
|
||||||
|
str(now),
|
||||||
|
member,
|
||||||
|
str(self._max_requests),
|
||||||
|
str(int(self._window_seconds) + 1),
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"rate_limit.redis.unavailable namespace=%s key=%s",
|
"rate_limit.redis.unavailable namespace=%s key=%s",
|
||||||
@@ -134,7 +170,7 @@ class RedisRateLimiter(RateLimiter):
|
|||||||
)
|
)
|
||||||
return True # fail-open
|
return True # fail-open
|
||||||
|
|
||||||
return count <= self._max_requests
|
return bool(allowed)
|
||||||
|
|
||||||
|
|
||||||
def _redact_url(url: str) -> str:
|
def _redact_url(url: str) -> str:
|
||||||
|
|||||||
@@ -15,68 +15,38 @@ from app.core.rate_limit import (
|
|||||||
)
|
)
|
||||||
from app.core.rate_limit_backend import RateLimitBackend
|
from app.core.rate_limit_backend import RateLimitBackend
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Fake Redis helpers for deterministic testing
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _FakePipeline:
|
|
||||||
"""Minimal sorted-set pipeline that executes against a _FakeRedis."""
|
|
||||||
|
|
||||||
def __init__(self, parent: _FakeRedis) -> None:
|
|
||||||
self._parent = parent
|
|
||||||
self._ops: list[tuple[str, ...]] = []
|
|
||||||
|
|
||||||
# Pipeline command stubs -- each just records intent and returns self
|
|
||||||
# so chaining works (even though our tests don't chain).
|
|
||||||
|
|
||||||
def zremrangebyscore(self, key: str, min_val: str, max_val: float) -> _FakePipeline:
|
|
||||||
self._ops.append(("zremrangebyscore", key, str(min_val), str(max_val)))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def zadd(self, key: str, mapping: dict[str, float]) -> _FakePipeline:
|
|
||||||
self._ops.append(("zadd", key, *next(iter(mapping.items()))))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def zcard(self, key: str) -> _FakePipeline:
|
|
||||||
self._ops.append(("zcard", key))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def expire(self, key: str, seconds: int) -> _FakePipeline:
|
|
||||||
self._ops.append(("expire", key, str(seconds)))
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def execute(self) -> list[object]:
|
|
||||||
results: list[object] = []
|
|
||||||
for op in self._ops:
|
|
||||||
cmd = op[0]
|
|
||||||
key = op[1]
|
|
||||||
zset = self._parent._sorted_sets.setdefault(key, {})
|
|
||||||
if cmd == "zremrangebyscore":
|
|
||||||
max_score = float(op[3])
|
|
||||||
expired = [m for m, s in zset.items() if s <= max_score]
|
|
||||||
for m in expired:
|
|
||||||
del zset[m]
|
|
||||||
results.append(len(expired))
|
|
||||||
elif cmd == "zadd":
|
|
||||||
member, score = op[2], float(op[3])
|
|
||||||
zset[member] = score
|
|
||||||
results.append(1)
|
|
||||||
elif cmd == "zcard":
|
|
||||||
results.append(len(zset))
|
|
||||||
elif cmd == "expire":
|
|
||||||
results.append(True)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeRedis:
|
class _FakeRedis:
|
||||||
"""Minimal in-process Redis fake supporting sorted-set pipeline ops."""
|
"""Minimal in-process Redis fake supporting the limiter Lua script."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._sorted_sets: dict[str, dict[str, float]] = {}
|
self._sorted_sets: dict[str, dict[str, float]] = {}
|
||||||
|
|
||||||
def pipeline(self, *, transaction: bool = True) -> _FakePipeline:
|
async def eval(
|
||||||
return _FakePipeline(self)
|
self,
|
||||||
|
script: str,
|
||||||
|
numkeys: int,
|
||||||
|
key: str,
|
||||||
|
cutoff: float,
|
||||||
|
now: float,
|
||||||
|
member: str,
|
||||||
|
max_requests: int,
|
||||||
|
ttl: int,
|
||||||
|
) -> int:
|
||||||
|
del script, numkeys, ttl
|
||||||
|
|
||||||
|
zset = self._sorted_sets.setdefault(key, {})
|
||||||
|
expired = [m for m, s in zset.items() if s <= float(cutoff)]
|
||||||
|
for m in expired:
|
||||||
|
del zset[m]
|
||||||
|
|
||||||
|
if len(zset) < int(max_requests):
|
||||||
|
zset[member] = float(now)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
oldest_member = min(zset, key=zset.__getitem__)
|
||||||
|
del zset[oldest_member]
|
||||||
|
zset[member] = float(now)
|
||||||
|
return 0
|
||||||
|
|
||||||
def ping(self) -> bool:
|
def ping(self) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -106,6 +76,19 @@ async def test_blocks_requests_over_limit() -> None:
|
|||||||
assert await limiter.is_allowed("client-a") is False
|
assert await limiter.is_allowed("client-a") is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_blocked_requests_extend_window_without_growing_memory() -> None:
|
||||||
|
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=1.0)
|
||||||
|
with patch("time.monotonic", side_effect=[0.0, 0.1, 0.2, 1.05, 1.21]):
|
||||||
|
assert await limiter.is_allowed("client-a") is True
|
||||||
|
assert await limiter.is_allowed("client-a") is True
|
||||||
|
assert await limiter.is_allowed("client-a") is False
|
||||||
|
assert await limiter.is_allowed("client-a") is False
|
||||||
|
assert await limiter.is_allowed("client-a") is True
|
||||||
|
|
||||||
|
assert len(limiter._buckets["client-a"]) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio()
|
@pytest.mark.asyncio()
|
||||||
async def test_separate_keys_have_independent_limits() -> None:
|
async def test_separate_keys_have_independent_limits() -> None:
|
||||||
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0)
|
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0)
|
||||||
@@ -198,6 +181,22 @@ async def test_redis_blocks_over_limit() -> None:
|
|||||||
assert await limiter.is_allowed("client-a") is False
|
assert await limiter.is_allowed("client-a") is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_redis_blocked_requests_extend_window_without_growing_storage() -> None:
|
||||||
|
fake = _FakeRedis()
|
||||||
|
limiter = _make_redis_limiter(fake, max_requests=2, window_seconds=1.0)
|
||||||
|
redis_key = "ratelimit:test:client-a"
|
||||||
|
|
||||||
|
with patch("time.time", side_effect=[0.0, 0.1, 0.2, 1.05, 1.21]):
|
||||||
|
assert await limiter.is_allowed("client-a") is True
|
||||||
|
assert await limiter.is_allowed("client-a") is True
|
||||||
|
assert await limiter.is_allowed("client-a") is False
|
||||||
|
assert await limiter.is_allowed("client-a") is False
|
||||||
|
assert await limiter.is_allowed("client-a") is True
|
||||||
|
|
||||||
|
assert len(fake._sorted_sets[redis_key]) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio()
|
@pytest.mark.asyncio()
|
||||||
async def test_redis_separate_keys_independent() -> None:
|
async def test_redis_separate_keys_independent() -> None:
|
||||||
fake = _FakeRedis()
|
fake = _FakeRedis()
|
||||||
@@ -231,17 +230,8 @@ async def test_redis_fail_open_on_error() -> None:
|
|||||||
fake = _FakeRedis()
|
fake = _FakeRedis()
|
||||||
limiter = _make_redis_limiter(fake, max_requests=1)
|
limiter = _make_redis_limiter(fake, max_requests=1)
|
||||||
|
|
||||||
# Make the pipeline raise on execute
|
broken_eval = MagicMock(side_effect=ConnectionError("Redis gone"))
|
||||||
def _broken_pipeline(*, transaction: bool = True) -> MagicMock:
|
limiter._client.eval = broken_eval # type: ignore[assignment]
|
||||||
pipe = MagicMock()
|
|
||||||
pipe.zremrangebyscore.return_value = pipe
|
|
||||||
pipe.zadd.return_value = pipe
|
|
||||||
pipe.zcard.return_value = pipe
|
|
||||||
pipe.expire.return_value = pipe
|
|
||||||
pipe.execute.side_effect = ConnectionError("Redis gone")
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Should still allow (fail-open) even though Redis is broken
|
# Should still allow (fail-open) even though Redis is broken
|
||||||
assert await limiter.is_allowed("client-a") is True
|
assert await limiter.is_allowed("client-a") is True
|
||||||
@@ -254,16 +244,7 @@ async def test_redis_fail_open_logs_warning() -> None:
|
|||||||
fake = _FakeRedis()
|
fake = _FakeRedis()
|
||||||
limiter = _make_redis_limiter(fake, max_requests=1)
|
limiter = _make_redis_limiter(fake, max_requests=1)
|
||||||
|
|
||||||
def _broken_pipeline(*, transaction: bool = True) -> MagicMock:
|
limiter._client.eval = MagicMock(side_effect=ConnectionError("Redis gone")) # type: ignore[assignment]
|
||||||
pipe = MagicMock()
|
|
||||||
pipe.zremrangebyscore.return_value = pipe
|
|
||||||
pipe.zadd.return_value = pipe
|
|
||||||
pipe.zcard.return_value = pipe
|
|
||||||
pipe.expire.return_value = pipe
|
|
||||||
pipe.execute.side_effect = ConnectionError("Redis gone")
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
|
|
||||||
|
|
||||||
with patch("app.core.rate_limit.logger") as mock_logger:
|
with patch("app.core.rate_limit.logger") as mock_logger:
|
||||||
await limiter.is_allowed("client-a")
|
await limiter.is_allowed("client-a")
|
||||||
|
|||||||
Reference in New Issue
Block a user