fix(security): Address PR review feedback

This commit is contained in:
Abhimanyu Saharan
2026-03-08 00:01:04 +05:30
parent b3cb604776
commit cc3024acc3
3 changed files with 107 additions and 93 deletions

View File

@@ -15,68 +15,38 @@ from app.core.rate_limit import (
)
from app.core.rate_limit_backend import RateLimitBackend
# ---------------------------------------------------------------------------
# Fake Redis helpers for deterministic testing
# ---------------------------------------------------------------------------
class _FakePipeline:
"""Minimal sorted-set pipeline that executes against a _FakeRedis."""
def __init__(self, parent: _FakeRedis) -> None:
self._parent = parent
self._ops: list[tuple[str, ...]] = []
# Pipeline command stubs -- each just records intent and returns self
# so chaining works (even though our tests don't chain).
def zremrangebyscore(self, key: str, min_val: str, max_val: float) -> _FakePipeline:
self._ops.append(("zremrangebyscore", key, str(min_val), str(max_val)))
return self
def zadd(self, key: str, mapping: dict[str, float]) -> _FakePipeline:
self._ops.append(("zadd", key, *next(iter(mapping.items()))))
return self
def zcard(self, key: str) -> _FakePipeline:
self._ops.append(("zcard", key))
return self
def expire(self, key: str, seconds: int) -> _FakePipeline:
self._ops.append(("expire", key, str(seconds)))
return self
async def execute(self) -> list[object]:
results: list[object] = []
for op in self._ops:
cmd = op[0]
key = op[1]
zset = self._parent._sorted_sets.setdefault(key, {})
if cmd == "zremrangebyscore":
max_score = float(op[3])
expired = [m for m, s in zset.items() if s <= max_score]
for m in expired:
del zset[m]
results.append(len(expired))
elif cmd == "zadd":
member, score = op[2], float(op[3])
zset[member] = score
results.append(1)
elif cmd == "zcard":
results.append(len(zset))
elif cmd == "expire":
results.append(True)
return results
class _FakeRedis:
"""Minimal in-process Redis fake supporting sorted-set pipeline ops."""
"""Minimal in-process Redis fake supporting the limiter Lua script."""
def __init__(self) -> None:
self._sorted_sets: dict[str, dict[str, float]] = {}
def pipeline(self, *, transaction: bool = True) -> _FakePipeline:
return _FakePipeline(self)
async def eval(
self,
script: str,
numkeys: int,
key: str,
cutoff: float,
now: float,
member: str,
max_requests: int,
ttl: int,
) -> int:
del script, numkeys, ttl
zset = self._sorted_sets.setdefault(key, {})
expired = [m for m, s in zset.items() if s <= float(cutoff)]
for m in expired:
del zset[m]
if len(zset) < int(max_requests):
zset[member] = float(now)
return 1
oldest_member = min(zset, key=zset.__getitem__)
del zset[oldest_member]
zset[member] = float(now)
return 0
def ping(self) -> bool:
return True
@@ -106,6 +76,19 @@ async def test_blocks_requests_over_limit() -> None:
assert await limiter.is_allowed("client-a") is False
@pytest.mark.asyncio()
async def test_blocked_requests_extend_window_without_growing_memory() -> None:
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=1.0)
with patch("time.monotonic", side_effect=[0.0, 0.1, 0.2, 1.05, 1.21]):
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert len(limiter._buckets["client-a"]) == 2
@pytest.mark.asyncio()
async def test_separate_keys_have_independent_limits() -> None:
limiter = InMemoryRateLimiter(max_requests=2, window_seconds=60.0)
@@ -198,6 +181,22 @@ async def test_redis_blocks_over_limit() -> None:
assert await limiter.is_allowed("client-a") is False
@pytest.mark.asyncio()
async def test_redis_blocked_requests_extend_window_without_growing_storage() -> None:
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=2, window_seconds=1.0)
redis_key = "ratelimit:test:client-a"
with patch("time.time", side_effect=[0.0, 0.1, 0.2, 1.05, 1.21]):
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is True
assert await limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is False
assert await limiter.is_allowed("client-a") is True
assert len(fake._sorted_sets[redis_key]) == 2
@pytest.mark.asyncio()
async def test_redis_separate_keys_independent() -> None:
fake = _FakeRedis()
@@ -231,17 +230,8 @@ async def test_redis_fail_open_on_error() -> None:
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=1)
# Make the pipeline raise on execute
def _broken_pipeline(*, transaction: bool = True) -> MagicMock:
pipe = MagicMock()
pipe.zremrangebyscore.return_value = pipe
pipe.zadd.return_value = pipe
pipe.zcard.return_value = pipe
pipe.expire.return_value = pipe
pipe.execute.side_effect = ConnectionError("Redis gone")
return pipe
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
broken_eval = MagicMock(side_effect=ConnectionError("Redis gone"))
limiter._client.eval = broken_eval # type: ignore[assignment]
# Should still allow (fail-open) even though Redis is broken
assert await limiter.is_allowed("client-a") is True
@@ -254,16 +244,7 @@ async def test_redis_fail_open_logs_warning() -> None:
fake = _FakeRedis()
limiter = _make_redis_limiter(fake, max_requests=1)
def _broken_pipeline(*, transaction: bool = True) -> MagicMock:
pipe = MagicMock()
pipe.zremrangebyscore.return_value = pipe
pipe.zadd.return_value = pipe
pipe.zcard.return_value = pipe
pipe.expire.return_value = pipe
pipe.execute.side_effect = ConnectionError("Redis gone")
return pipe
limiter._client.pipeline = _broken_pipeline # type: ignore[assignment]
limiter._client.eval = MagicMock(side_effect=ConnectionError("Redis gone")) # type: ignore[assignment]
with patch("app.core.rate_limit.logger") as mock_logger:
await limiter.is_allowed("client-a")