fix(security): Address PR review feedback

This commit is contained in:
Abhimanyu Saharan
2026-03-08 00:01:04 +05:30
parent b3cb604776
commit cc3024acc3
3 changed files with 107 additions and 93 deletions

View File

@@ -122,19 +122,16 @@ backend-migration-check: ## Validate migration graph + reversible path on clean
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \ LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
BASE_URL=http://localhost:8000 \ BASE_URL=http://localhost:8000 \
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \ DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
BASE_URL=http://localhost:8000 \
uv run alembic upgrade head && \ uv run alembic upgrade head && \
AUTH_MODE=local \ AUTH_MODE=local \
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \ LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
BASE_URL=http://localhost:8000 \ BASE_URL=http://localhost:8000 \
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \ DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
BASE_URL=http://localhost:8000 \
uv run alembic downgrade base && \ uv run alembic downgrade base && \
AUTH_MODE=local \ AUTH_MODE=local \
LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \ LOCAL_AUTH_TOKEN=ci-local-token-ci-local-token-ci-local-token-ci-local-token \
BASE_URL=http://localhost:8000 \ BASE_URL=http://localhost:8000 \
DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \ DATABASE_URL=postgresql+psycopg://postgres:postgres@localhost:55432/migration_ci \
BASE_URL=http://localhost:8000 \
uv run alembic upgrade head uv run alembic upgrade head
.PHONY: build .PHONY: build

View File

@@ -12,6 +12,7 @@ import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import deque from collections import deque
from threading import Lock from threading import Lock
from typing import Awaitable, cast
import redis as redis_lib import redis as redis_lib
import redis.asyncio as aioredis import redis.asyncio as aioredis
@@ -24,6 +25,27 @@ logger = get_logger(__name__)
# Run a full sweep of all keys every 128 calls to is_allowed. # Run a full sweep of all keys every 128 calls to is_allowed.
_CLEANUP_INTERVAL = 128 _CLEANUP_INTERVAL = 128
# Redis sliding-window script that bounds per-key storage to
# ``max_requests`` while preserving the current "blocked attempts extend
# the window" behavior by retaining the most recent attempts.
_REDIS_IS_ALLOWED_SCRIPT = """
redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", ARGV[1])
local count = redis.call("ZCARD", KEYS[1])
if count < tonumber(ARGV[4]) then
redis.call("ZADD", KEYS[1], ARGV[2], ARGV[3])
redis.call("EXPIRE", KEYS[1], ARGV[5])
return 1
end
local oldest = redis.call("ZRANGE", KEYS[1], 0, 0)
if oldest[1] then
redis.call("ZREM", KEYS[1], oldest[1])
end
redis.call("ZADD", KEYS[1], ARGV[2], ARGV[3])
redis.call("EXPIRE", KEYS[1], ARGV[5])
return 0
"""
# Shared async Redis clients keyed by URL to avoid duplicate connection pools. # Shared async Redis clients keyed by URL to avoid duplicate connection pools.
_async_redis_clients: dict[str, aioredis.Redis] = {} _async_redis_clients: dict[str, aioredis.Redis] = {}
@@ -81,17 +103,25 @@ class InMemoryRateLimiter(RateLimiter):
# Prune expired entries from the front (timestamps are monotonic) # Prune expired entries from the front (timestamps are monotonic)
while timestamps and timestamps[0] <= cutoff: while timestamps and timestamps[0] <= cutoff:
timestamps.popleft() timestamps.popleft()
if len(timestamps) < self._max_requests:
timestamps.append(now) timestamps.append(now)
return len(timestamps) <= self._max_requests return True
# Retain only the latest ``max_requests`` attempts so
# sustained abuse keeps extending the window without letting
# the bucket grow unbounded.
timestamps.popleft()
timestamps.append(now)
return False
class RedisRateLimiter(RateLimiter): class RedisRateLimiter(RateLimiter):
"""Redis-backed sliding-window rate limiter using sorted sets. """Redis-backed sliding-window rate limiter using sorted sets.
Each key is stored as a Redis sorted set where members are unique Each key is stored as a Redis sorted set where members are unique
request identifiers and scores are wall-clock timestamps. An async request identifiers and scores are wall-clock timestamps. A Lua
pipeline prunes expired entries, adds the new request, counts the script prunes expired entries, updates the set, and keeps storage
window, and sets a TTL — all in a single round-trip. bounded to the most recent ``max_requests`` attempts.
Fail-open: if Redis is unreachable during a request, the request is Fail-open: if Redis is unreachable during a request, the request is
allowed and a warning is logged. allowed and a warning is logged.
@@ -118,13 +148,19 @@ class RedisRateLimiter(RateLimiter):
member = f"{now}:{uuid.uuid4().hex[:8]}" member = f"{now}:{uuid.uuid4().hex[:8]}"
try: try:
pipe = self._client.pipeline(transaction=True) allowed = await cast(
pipe.zremrangebyscore(redis_key, "-inf", cutoff) Awaitable[object],
pipe.zadd(redis_key, {member: now}) self._client.eval(
pipe.zcard(redis_key) _REDIS_IS_ALLOWED_SCRIPT,
pipe.expire(redis_key, int(self._window_seconds) + 1) 1,
results = await pipe.execute() redis_key,
count: int = results[2] str(cutoff),
str(now),
member,
str(self._max_requests),
str(int(self._window_seconds) + 1),
),
)
except Exception: except Exception:
logger.warning( logger.warning(
"rate_limit.redis.unavailable namespace=%s key=%s", "rate_limit.redis.unavailable namespace=%s key=%s",
@@ -134,7 +170,7 @@ class RedisRateLimiter(RateLimiter):
) )
return True # fail-open return True # fail-open
return count <= self._max_requests return bool(allowed)
def _redact_url(url: str) -> str: def _redact_url(url: str) -> str:

View File

@@ -15,68 +15,38 @@ from app.core.rate_limit import (
) )
from app.core.rate_limit_backend import RateLimitBackend from app.core.rate_limit_backend import RateLimitBackend
# ---------------------------------------------------------------------------
# Fake Redis helpers for deterministic testing
# ---------------------------------------------------------------------------
class _FakePipeline:
    """Minimal sorted-set pipeline that executes against a _FakeRedis.

    Each command method records a stringified op tuple; ``execute()``
    replays them against the parent's in-memory sorted sets and returns
    one result per queued op, mirroring redis-py pipeline semantics.
    """

    def __init__(self, parent: _FakeRedis) -> None:
        # Backing fake whose ``_sorted_sets`` dict the replayed ops mutate.
        self._parent = parent
        # Queued operations as (command, key, *args) tuples.
        self._ops: list[tuple[str, ...]] = []

    # Pipeline command stubs -- each just records intent and returns self
    # so chaining works (even though our tests don't chain).

    def zremrangebyscore(self, key: str, min_val: str, max_val: float) -> _FakePipeline:
        self._ops.append(("zremrangebyscore", key, str(min_val), str(max_val)))
        return self

    def zadd(self, key: str, mapping: dict[str, float]) -> _FakePipeline:
        # Only the first member/score pair is recorded -- the limiter
        # always passes a single-entry mapping.
        self._ops.append(("zadd", key, *next(iter(mapping.items()))))
        return self

    def zcard(self, key: str) -> _FakePipeline:
        self._ops.append(("zcard", key))
        return self

    def expire(self, key: str, seconds: int) -> _FakePipeline:
        self._ops.append(("expire", key, str(seconds)))
        return self

    async def execute(self) -> list[object]:
        """Replay queued ops in order; return one result per op."""
        results: list[object] = []
        for op in self._ops:
            cmd = op[0]
            key = op[1]
            zset = self._parent._sorted_sets.setdefault(key, {})
            if cmd == "zremrangebyscore":
                # Drop members scored at or below the cutoff (op[3] is the
                # max bound of the removal range); result is the count removed.
                max_score = float(op[3])
                expired = [m for m, s in zset.items() if s <= max_score]
                for m in expired:
                    del zset[m]
                results.append(len(expired))
            elif cmd == "zadd":
                member, score = op[2], float(op[3])
                zset[member] = score
                results.append(1)
            elif cmd == "zcard":
                results.append(len(zset))
            elif cmd == "expire":
                # TTLs are not modeled by the fake; just report success.
                results.append(True)
        return results
class _FakeRedis: class _FakeRedis:
"""Minimal in-process Redis fake supporting sorted-set pipeline ops.""" """Minimal in-process Redis fake supporting the limiter Lua script."""
def __init__(self) -> None: def __init__(self) -> None:
self._sorted_sets: dict[str, dict[str, float]] = {} self._sorted_sets: dict[str, dict[str, float]] = {}
def pipeline(self, *, transaction: bool = True) -> _FakePipeline: async def eval(
return _FakePipeline(self) self,
script: str,
numkeys: int,
key: str,
cutoff: float,
now: float,
member: str,
max_requests: int,
ttl: int,
) -> int:
del script, numkeys, ttl
zset = self._sorted_sets.setdefault(key, {})
expired = [m for m, s in zset.items() if s <= float(cutoff)]
for m in expired:
del zset[m]
if len(zset) < int(max_requests):
zset[member] = float(now)
return 1
oldest_member = min(zset, key=zset.__getitem__)
del zset[oldest_member]
zset[member] = float(now)
return 0
def ping(self) -> bool: def ping(self) -> bool:
return True return True
@@ -106,6 +76,19 @@ async def test_blocks_requests_over_limit() -> None:
assert await limiter.is_allowed("client-a") is False assert await limiter.is_allowed("client-a") is False
@pytest.mark.asyncio()
async def test_blocked_requests_extend_window_without_growing_memory() -> None:
    """Blocked attempts extend the window, but the bucket stays bounded."""
    rl = InMemoryRateLimiter(max_requests=2, window_seconds=1.0)
    clock_ticks = [0.0, 0.1, 0.2, 1.05, 1.21]
    expected_verdicts = [True, True, False, False, True]
    with patch("time.monotonic", side_effect=clock_ticks):
        for verdict in expected_verdicts:
            assert await rl.is_allowed("client-a") is verdict
    # Sustained abuse never grows storage past max_requests entries.
    assert len(rl._buckets["client-a"]) == 2
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_separate_keys_have_independent_limits() -> None: async def test_separate_keys_have_independent_limits() -> None:
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0) limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0)
@@ -198,6 +181,22 @@ async def test_redis_blocks_over_limit() -> None:
assert await limiter.is_allowed("client-a") is False assert await limiter.is_allowed("client-a") is False
@pytest.mark.asyncio()
async def test_redis_blocked_requests_extend_window_without_growing_storage() -> None:
    """The Lua path keeps per-key storage bounded while extending the window."""
    backend = _FakeRedis()
    rl = _make_redis_limiter(backend, max_requests=2, window_seconds=1.0)
    expected_verdicts = [True, True, False, False, True]
    with patch("time.time", side_effect=[0.0, 0.1, 0.2, 1.05, 1.21]):
        for verdict in expected_verdicts:
            assert await rl.is_allowed("client-a") is verdict
    # The sorted set retains only the newest max_requests attempts.
    assert len(backend._sorted_sets["ratelimit:test:client-a"]) == 2
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_redis_separate_keys_independent() -> None: async def test_redis_separate_keys_independent() -> None:
fake = _FakeRedis() fake = _FakeRedis()
@@ -231,17 +230,8 @@ async def test_redis_fail_open_on_error() -> None:
fake = _FakeRedis() fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=1) limiter = _make_redis_limiter(fake, max_requests=1)
# Make the pipeline raise on execute broken_eval = MagicMock(side_effect=ConnectionError("Redis gone"))
def _broken_pipeline(*, transaction: bool = True) -> MagicMock: limiter._client.eval = broken_eval # type: ignore[assignment]
pipe = MagicMock()
pipe.zremrangebyscore.return_value = pipe
pipe.zadd.return_value = pipe
pipe.zcard.return_value = pipe
pipe.expire.return_value = pipe
pipe.execute.side_effect = ConnectionError("Redis gone")
return pipe
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
# Should still allow (fail-open) even though Redis is broken # Should still allow (fail-open) even though Redis is broken
assert await limiter.is_allowed("client-a") is True assert await limiter.is_allowed("client-a") is True
@@ -254,16 +244,7 @@ async def test_redis_fail_open_logs_warning() -> None:
fake = _FakeRedis() fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=1) limiter = _make_redis_limiter(fake, max_requests=1)
def _broken_pipeline(*, transaction: bool = True) -> MagicMock: limiter._client.eval = MagicMock(side_effect=ConnectionError("Redis gone")) # type: ignore[assignment]
pipe = MagicMock()
pipe.zremrangebyscore.return_value = pipe
pipe.zadd.return_value = pipe
pipe.zcard.return_value = pipe
pipe.expire.return_value = pipe
pipe.execute.side_effect = ConnectionError("Redis gone")
return pipe
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
with patch("app.core.rate_limit.logger") as mock_logger: with patch("app.core.rate_limit.logger") as mock_logger:
await limiter.is_allowed("client-a") await limiter.is_allowed("client-a")