feat: add trusted client-IP extraction from proxy headers
Add get_client_ip() helper that inspects Forwarded and X-Forwarded-For headers only when the direct peer is in TRUSTED_PROXIES (comma-separated IPs/CIDRs). Replaces raw request.client.host in rate-limit and webhook source_ip to prevent all traffic collapsing behind a reverse proxy IP. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
Abhimanyu Saharan
parent
24e40f1153
commit
f1bcf72810
@@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
from app.api.deps import get_board_for_user_read, get_board_for_user_write, get_board_or_404
|
from app.api.deps import get_board_for_user_read, get_board_for_user_write, get_board_or_404
|
||||||
|
from app.core.client_ip import get_client_ip
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.core.rate_limit import webhook_ingest_limiter
|
from app.core.rate_limit import webhook_ingest_limiter
|
||||||
@@ -507,7 +508,7 @@ async def ingest_board_webhook(
|
|||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
) -> BoardWebhookIngestResponse:
|
) -> BoardWebhookIngestResponse:
|
||||||
"""Open inbound webhook endpoint that stores payloads and nudges the board lead."""
|
"""Open inbound webhook endpoint that stores payloads and nudges the board lead."""
|
||||||
client_ip = request.client.host if request.client else "unknown"
|
client_ip = get_client_ip(request)
|
||||||
if not webhook_ingest_limiter.is_allowed(client_ip):
|
if not webhook_ingest_limiter.is_allowed(client_ip):
|
||||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
||||||
webhook = await _require_board_webhook(
|
webhook = await _require_board_webhook(
|
||||||
@@ -520,7 +521,7 @@ async def ingest_board_webhook(
|
|||||||
extra={
|
extra={
|
||||||
"board_id": str(board.id),
|
"board_id": str(board.id),
|
||||||
"webhook_id": str(webhook.id),
|
"webhook_id": str(webhook.id),
|
||||||
"source_ip": request.client.host if request.client else None,
|
"source_ip": client_ip,
|
||||||
"content_type": request.headers.get("content-type"),
|
"content_type": request.headers.get("content-type"),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -568,7 +569,7 @@ async def ingest_board_webhook(
|
|||||||
webhook_id=webhook.id,
|
webhook_id=webhook.id,
|
||||||
payload=payload_value,
|
payload=payload_value,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
source_ip=request.client.host if request.client else None,
|
source_ip=client_ip,
|
||||||
content_type=content_type,
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
session.add(payload)
|
session.add(payload)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from fastapi import Depends, Header, HTTPException, Request, status
|
|||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
from app.core.agent_tokens import verify_agent_token
|
from app.core.agent_tokens import verify_agent_token
|
||||||
|
from app.core.client_ip import get_client_ip
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.core.rate_limit import agent_auth_limiter
|
from app.core.rate_limit import agent_auth_limiter
|
||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
@@ -113,7 +114,7 @@ async def get_agent_auth_context(
|
|||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
) -> AgentAuthContext:
|
) -> AgentAuthContext:
|
||||||
"""Require and validate agent auth token from request headers."""
|
"""Require and validate agent auth token from request headers."""
|
||||||
client_ip = request.client.host if request.client else "unknown"
|
client_ip = get_client_ip(request)
|
||||||
if not agent_auth_limiter.is_allowed(client_ip):
|
if not agent_auth_limiter.is_allowed(client_ip):
|
||||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
||||||
resolved = _resolve_agent_token(
|
resolved = _resolve_agent_token(
|
||||||
@@ -174,7 +175,7 @@ async def get_agent_auth_context_optional(
|
|||||||
# guessing via the optional auth path. Scoped to X-Agent-Token so that
|
# guessing via the optional auth path. Scoped to X-Agent-Token so that
|
||||||
# normal user Authorization headers are not throttled.
|
# normal user Authorization headers are not throttled.
|
||||||
if agent_token:
|
if agent_token:
|
||||||
client_ip = request.client.host if request.client else "unknown"
|
client_ip = get_client_ip(request)
|
||||||
if not agent_auth_limiter.is_allowed(client_ip):
|
if not agent_auth_limiter.is_allowed(client_ip):
|
||||||
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
|
||||||
agent = await _find_agent_for_token(session, resolved)
|
agent = await _find_agent_for_token(session, resolved)
|
||||||
|
|||||||
121
backend/app/core/client_ip.py
Normal file
121
backend/app/core/client_ip.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Trusted client-IP extraction from proxy headers.
|
||||||
|
|
||||||
|
In proxied deployments ``request.client.host`` returns the reverse-proxy
|
||||||
|
IP. This module provides :func:`get_client_ip` which inspects the
|
||||||
|
``Forwarded`` and ``X-Forwarded-For`` headers **only** when the direct
|
||||||
|
peer is in the configured ``TRUSTED_PROXIES`` list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import re
|
||||||
|
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# RFC 7239 ``for=`` directive value. Handles quoted/unquoted, optional
|
||||||
|
# port, and IPv6 bracket notation.
|
||||||
|
_FORWARDED_FOR_RE = re.compile(r'for="?(\[[^\]]+\]|[^";,\s]+)', re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_trusted_networks(raw: str) -> list[IPv4Network | IPv6Network]:
|
||||||
|
"""Parse a comma-separated list of IPs/CIDRs into network objects."""
|
||||||
|
networks: list[IPv4Network | IPv6Network] = []
|
||||||
|
for entry in raw.split(","):
|
||||||
|
entry = entry.strip()
|
||||||
|
if not entry:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
networks.append(ipaddress.ip_network(entry, strict=False))
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("trusted_proxies: ignoring invalid entry %r", entry)
|
||||||
|
return networks
|
||||||
|
|
||||||
|
|
||||||
|
def _is_trusted(peer_ip: str, networks: list[IPv4Network | IPv6Network]) -> bool:
|
||||||
|
"""Check whether *peer_ip* falls within any trusted network."""
|
||||||
|
try:
|
||||||
|
addr = ipaddress.ip_address(peer_ip)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return any(addr in net for net in networks)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_port(raw: str) -> str:
|
||||||
|
"""Strip port suffix from a ``Forwarded: for=`` value.
|
||||||
|
|
||||||
|
Handles ``1.2.3.4:8080``, ``[::1]:8080``, and bare addresses.
|
||||||
|
"""
|
||||||
|
# Bracketed IPv6: ``[::1]:port`` or ``[::1]``
|
||||||
|
if raw.startswith("["):
|
||||||
|
bracket_end = raw.find("]")
|
||||||
|
if bracket_end != -1:
|
||||||
|
return raw[1:bracket_end]
|
||||||
|
return raw.strip("[]")
|
||||||
|
# IPv4 with port: ``1.2.3.4:8080``
|
||||||
|
if "." in raw and ":" in raw:
|
||||||
|
return raw.rsplit(":", 1)[0]
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_from_forwarded(header: str) -> str | None:
|
||||||
|
"""Return the leftmost ``for=`` client IP from an RFC 7239 Forwarded header."""
|
||||||
|
match = _FORWARDED_FOR_RE.search(header)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
value = _strip_port(match.group(1))
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_from_x_forwarded_for(header: str) -> str | None:
|
||||||
|
"""Return the leftmost entry from an X-Forwarded-For header."""
|
||||||
|
first = header.split(",", 1)[0].strip()
|
||||||
|
return first or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_ip(request: Request) -> str:
|
||||||
|
"""Extract the real client IP, respecting proxy headers when trusted.
|
||||||
|
|
||||||
|
Falls back to ``request.client.host`` when no trusted proxies are
|
||||||
|
configured or the direct peer is not trusted.
|
||||||
|
"""
|
||||||
|
peer_ip = request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
if not _trusted_networks:
|
||||||
|
return peer_ip
|
||||||
|
|
||||||
|
if not _is_trusted(peer_ip, _trusted_networks):
|
||||||
|
return peer_ip
|
||||||
|
|
||||||
|
# Prefer standardized Forwarded header (RFC 7239).
|
||||||
|
forwarded = request.headers.get("forwarded")
|
||||||
|
if forwarded:
|
||||||
|
client_ip = _extract_from_forwarded(forwarded)
|
||||||
|
if client_ip:
|
||||||
|
return client_ip
|
||||||
|
|
||||||
|
# Fall back to de-facto X-Forwarded-For.
|
||||||
|
xff = request.headers.get("x-forwarded-for")
|
||||||
|
if xff:
|
||||||
|
client_ip = _extract_from_x_forwarded_for(xff)
|
||||||
|
if client_ip:
|
||||||
|
return client_ip
|
||||||
|
|
||||||
|
return peer_ip
|
||||||
|
|
||||||
|
|
||||||
|
def _load_trusted_networks() -> list[IPv4Network | IPv6Network]:
|
||||||
|
"""Load trusted proxy networks from settings (called once at import)."""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
return _parse_trusted_networks(settings.trusted_proxies)
|
||||||
|
|
||||||
|
|
||||||
|
_trusted_networks: list[IPv4Network | IPv6Network] = _load_trusted_networks()
|
||||||
@@ -65,6 +65,11 @@ class Settings(BaseSettings):
|
|||||||
rate_limit_backend: RateLimitBackend = RateLimitBackend.MEMORY
|
rate_limit_backend: RateLimitBackend = RateLimitBackend.MEMORY
|
||||||
rate_limit_redis_url: str = ""
|
rate_limit_redis_url: str = ""
|
||||||
|
|
||||||
|
# Trusted reverse-proxy IPs/CIDRs for client-IP extraction from
|
||||||
|
# Forwarded / X-Forwarded-For headers. Comma-separated.
|
||||||
|
# Leave empty to always use the direct peer address.
|
||||||
|
trusted_proxies: str = ""
|
||||||
|
|
||||||
# Database lifecycle
|
# Database lifecycle
|
||||||
db_auto_migrate: bool = False
|
db_auto_migrate: bool = False
|
||||||
|
|
||||||
|
|||||||
179
backend/tests/test_client_ip.py
Normal file
179
backend/tests/test_client_ip.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""Tests for trusted client-IP extraction."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.client_ip import (
|
||||||
|
_extract_from_forwarded,
|
||||||
|
_extract_from_x_forwarded_for,
|
||||||
|
_parse_trusted_networks,
|
||||||
|
_strip_port,
|
||||||
|
get_client_ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeClient:
|
||||||
|
def __init__(self, host: str) -> None:
|
||||||
|
self.host = host
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRequest:
|
||||||
|
def __init__(self, peer_ip: str, headers: dict[str, str] | None = None) -> None:
|
||||||
|
self.client = _FakeClient(peer_ip)
|
||||||
|
self._headers = headers or {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def headers(self) -> dict[str, str]:
|
||||||
|
return self._headers
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests for internal helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_port_ipv4() -> None:
|
||||||
|
assert _strip_port("1.2.3.4:8080") == "1.2.3.4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_port_ipv4_no_port() -> None:
|
||||||
|
assert _strip_port("1.2.3.4") == "1.2.3.4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_port_ipv6_bracketed_with_port() -> None:
|
||||||
|
assert _strip_port("[::1]:8080") == "::1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_port_ipv6_bracketed_no_port() -> None:
|
||||||
|
assert _strip_port("[::1]") == "::1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forwarded_simple() -> None:
|
||||||
|
assert _extract_from_forwarded("for=192.0.2.60") == "192.0.2.60"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forwarded_quoted_with_port() -> None:
|
||||||
|
assert _extract_from_forwarded('for="192.0.2.60:8080"') == "192.0.2.60"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forwarded_ipv6() -> None:
|
||||||
|
assert _extract_from_forwarded('for="[2001:db8::1]"') == "2001:db8::1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forwarded_multiple_takes_first() -> None:
|
||||||
|
assert _extract_from_forwarded("for=203.0.113.50, for=198.51.100.1") == "203.0.113.50"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forwarded_with_other_directives() -> None:
|
||||||
|
assert _extract_from_forwarded("for=192.0.2.43;proto=https;by=203.0.113.60") == "192.0.2.43"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forwarded_empty() -> None:
|
||||||
|
assert _extract_from_forwarded("proto=https") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_xff_simple() -> None:
|
||||||
|
assert _extract_from_x_forwarded_for("203.0.113.50") == "203.0.113.50"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_xff_multiple_takes_first() -> None:
|
||||||
|
assert _extract_from_x_forwarded_for("203.0.113.50, 198.51.100.1, 10.0.0.1") == "203.0.113.50"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_xff_empty() -> None:
|
||||||
|
assert _extract_from_x_forwarded_for("") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_trusted_networks_valid() -> None:
|
||||||
|
nets = _parse_trusted_networks("127.0.0.1, 10.0.0.0/8, ::1")
|
||||||
|
assert len(nets) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_trusted_networks_empty() -> None:
|
||||||
|
assert _parse_trusted_networks("") == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_trusted_networks_ignores_invalid() -> None:
|
||||||
|
nets = _parse_trusted_networks("127.0.0.1, not-an-ip, 10.0.0.0/8")
|
||||||
|
assert len(nets) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests for get_client_ip
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_returns_peer_ip_when_no_trusted_proxies() -> None:
|
||||||
|
req = _FakeRequest("10.0.0.1", {"x-forwarded-for": "203.0.113.50"})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", []):
|
||||||
|
assert get_client_ip(req) == "10.0.0.1" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_returns_peer_ip_when_peer_not_trusted() -> None:
|
||||||
|
nets = _parse_trusted_networks("172.16.0.0/12")
|
||||||
|
req = _FakeRequest("10.0.0.1", {"x-forwarded-for": "203.0.113.50"})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "10.0.0.1" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extracts_from_x_forwarded_for() -> None:
|
||||||
|
nets = _parse_trusted_networks("10.0.0.1")
|
||||||
|
req = _FakeRequest("10.0.0.1", {"x-forwarded-for": "203.0.113.50, 10.0.0.1"})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "203.0.113.50" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extracts_from_forwarded_header() -> None:
|
||||||
|
nets = _parse_trusted_networks("10.0.0.1")
|
||||||
|
req = _FakeRequest("10.0.0.1", {"forwarded": "for=203.0.113.50;proto=https"})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "203.0.113.50" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_forwarded_takes_precedence_over_xff() -> None:
|
||||||
|
nets = _parse_trusted_networks("10.0.0.1")
|
||||||
|
req = _FakeRequest(
|
||||||
|
"10.0.0.1",
|
||||||
|
{
|
||||||
|
"forwarded": "for=198.51.100.1",
|
||||||
|
"x-forwarded-for": "203.0.113.50",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "198.51.100.1" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_returns_peer_when_headers_empty() -> None:
|
||||||
|
nets = _parse_trusted_networks("10.0.0.1")
|
||||||
|
req = _FakeRequest("10.0.0.1", {})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "10.0.0.1" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cidr_matching() -> None:
|
||||||
|
nets = _parse_trusted_networks("10.0.0.0/8")
|
||||||
|
req = _FakeRequest("10.255.0.1", {"x-forwarded-for": "203.0.113.50"})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "203.0.113.50" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_strips_port_from_forwarded() -> None:
|
||||||
|
nets = _parse_trusted_networks("10.0.0.1")
|
||||||
|
req = _FakeRequest("10.0.0.1", {"forwarded": 'for="192.0.2.60:8080"'})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "192.0.2.60" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_strips_port_from_forwarded_ipv6() -> None:
|
||||||
|
nets = _parse_trusted_networks("10.0.0.1")
|
||||||
|
req = _FakeRequest("10.0.0.1", {"forwarded": 'for="[2001:db8::1]:9090"'})
|
||||||
|
with patch("app.core.client_ip._trusted_networks", nets):
|
||||||
|
assert get_client_ip(req) == "2001:db8::1" # type: ignore[arg-type]
|
||||||
Reference in New Issue
Block a user