feat: add trusted client-IP extraction from proxy headers

Add get_client_ip() helper that inspects Forwarded and X-Forwarded-For
headers only when the direct peer is in TRUSTED_PROXIES (comma-separated
IPs/CIDRs). Replaces raw request.client.host in rate-limit and webhook
source_ip to prevent all traffic collapsing behind a reverse proxy IP.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hugh Brown
2026-03-04 11:51:02 -07:00
committed by Abhimanyu Saharan
parent 24e40f1153
commit f1bcf72810
5 changed files with 312 additions and 5 deletions

View File

@@ -23,6 +23,7 @@ from fastapi import Depends, Header, HTTPException, Request, status
from sqlmodel import col, select
from app.core.agent_tokens import verify_agent_token
from app.core.client_ip import get_client_ip
from app.core.logging import get_logger
from app.core.rate_limit import agent_auth_limiter
from app.core.time import utcnow
@@ -113,7 +114,7 @@ async def get_agent_auth_context(
session: AsyncSession = SESSION_DEP,
) -> AgentAuthContext:
"""Require and validate agent auth token from request headers."""
client_ip = request.client.host if request.client else "unknown"
client_ip = get_client_ip(request)
if not agent_auth_limiter.is_allowed(client_ip):
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
resolved = _resolve_agent_token(
@@ -174,7 +175,7 @@ async def get_agent_auth_context_optional(
# guessing via the optional auth path. Scoped to X-Agent-Token so that
# normal user Authorization headers are not throttled.
if agent_token:
client_ip = request.client.host if request.client else "unknown"
client_ip = get_client_ip(request)
if not agent_auth_limiter.is_allowed(client_ip):
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)
agent = await _find_agent_for_token(session, resolved)

View File

@@ -0,0 +1,121 @@
"""Trusted client-IP extraction from proxy headers.
In proxied deployments ``request.client.host`` returns the reverse-proxy
IP. This module provides :func:`get_client_ip` which inspects the
``Forwarded`` and ``X-Forwarded-For`` headers **only** when the direct
peer is in the configured ``TRUSTED_PROXIES`` list.
"""
from __future__ import annotations
import ipaddress
import re
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
from typing import TYPE_CHECKING
from app.core.logging import get_logger
if TYPE_CHECKING:
from fastapi import Request
logger = get_logger(__name__)
# RFC 7239 ``for=`` directive value. Handles quoted/unquoted, optional
# port, and IPv6 bracket notation.
_FORWARDED_FOR_RE = re.compile(r'for="?(\[[^\]]+\]|[^";,\s]+)', re.IGNORECASE)
def _parse_trusted_networks(raw: str) -> list[IPv4Network | IPv6Network]:
"""Parse a comma-separated list of IPs/CIDRs into network objects."""
networks: list[IPv4Network | IPv6Network] = []
for entry in raw.split(","):
entry = entry.strip()
if not entry:
continue
try:
networks.append(ipaddress.ip_network(entry, strict=False))
except ValueError:
logger.warning("trusted_proxies: ignoring invalid entry %r", entry)
return networks
def _is_trusted(peer_ip: str, networks: list[IPv4Network | IPv6Network]) -> bool:
"""Check whether *peer_ip* falls within any trusted network."""
try:
addr = ipaddress.ip_address(peer_ip)
except ValueError:
return False
return any(addr in net for net in networks)
def _strip_port(raw: str) -> str:
"""Strip port suffix from a ``Forwarded: for=`` value.
Handles ``1.2.3.4:8080``, ``[::1]:8080``, and bare addresses.
"""
# Bracketed IPv6: ``[::1]:port`` or ``[::1]``
if raw.startswith("["):
bracket_end = raw.find("]")
if bracket_end != -1:
return raw[1:bracket_end]
return raw.strip("[]")
# IPv4 with port: ``1.2.3.4:8080``
if "." in raw and ":" in raw:
return raw.rsplit(":", 1)[0]
return raw
def _extract_from_forwarded(header: str) -> str | None:
"""Return the leftmost ``for=`` client IP from an RFC 7239 Forwarded header."""
match = _FORWARDED_FOR_RE.search(header)
if not match:
return None
value = _strip_port(match.group(1))
return value or None
def _extract_from_x_forwarded_for(header: str) -> str | None:
"""Return the leftmost entry from an X-Forwarded-For header."""
first = header.split(",", 1)[0].strip()
return first or None
def get_client_ip(request: Request) -> str:
"""Extract the real client IP, respecting proxy headers when trusted.
Falls back to ``request.client.host`` when no trusted proxies are
configured or the direct peer is not trusted.
"""
peer_ip = request.client.host if request.client else "unknown"
if not _trusted_networks:
return peer_ip
if not _is_trusted(peer_ip, _trusted_networks):
return peer_ip
# Prefer standardized Forwarded header (RFC 7239).
forwarded = request.headers.get("forwarded")
if forwarded:
client_ip = _extract_from_forwarded(forwarded)
if client_ip:
return client_ip
# Fall back to de-facto X-Forwarded-For.
xff = request.headers.get("x-forwarded-for")
if xff:
client_ip = _extract_from_x_forwarded_for(xff)
if client_ip:
return client_ip
return peer_ip
def _load_trusted_networks() -> list[IPv4Network | IPv6Network]:
"""Load trusted proxy networks from settings (called once at import)."""
from app.core.config import settings
return _parse_trusted_networks(settings.trusted_proxies)
_trusted_networks: list[IPv4Network | IPv6Network] = _load_trusted_networks()

View File

@@ -65,6 +65,11 @@ class Settings(BaseSettings):
rate_limit_backend: RateLimitBackend = RateLimitBackend.MEMORY
rate_limit_redis_url: str = ""
# Trusted reverse-proxy IPs/CIDRs for client-IP extraction from
# Forwarded / X-Forwarded-For headers. Comma-separated.
# Leave empty to always use the direct peer address.
trusted_proxies: str = ""
# Database lifecycle
db_auto_migrate: bool = False