295 lines
10 KiB
Python
295 lines
10 KiB
Python
"""Global exception handlers and request-id middleware for FastAPI.
|
|
|
|
This module standardizes two operational behaviors:
|
|
|
|
1) **Request IDs**
|
|
- Every response includes an `X-Request-Id` header.
|
|
- Clients may supply their own request id; otherwise we generate one.
|
|
- The request id is propagated into logs via context vars.
|
|
|
|
2) **Error responses**
|
|
- Errors are returned as JSON with a stable top-level shape:
|
|
`{ "detail": ..., "request_id": ... }`
|
|
- Validation errors (`422`) return structured field errors.
|
|
- Unhandled errors are logged at ERROR and return a generic 500.
|
|
|
|
Design notes:
|
|
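- The request-id middleware is added last, making it the outermost *user*
  middleware, so it still runs when other middleware (e.g. a CORS preflight)
  returns early. Starlette's own `ServerErrorMiddleware` still wraps it, so the
  unhandled-exception handler sets `X-Request-Id` on its 500 response itself.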
|
|
- Health-check endpoints (`/health`, `/healthz`, `/readyz`) are excluded from
  request logs unless `settings.request_log_include_health` is enabled, to
  reduce noise.
|
|
"""
|
|
|
|
from __future__ import annotations

from collections.abc import Awaitable, Callable
from time import perf_counter
from typing import TYPE_CHECKING, Any, Final
from uuid import uuid4

from fastapi import FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response

from app.core.config import settings
from app.core.logging import (
    TRACE_LEVEL,
    get_logger,
    reset_request_id,
    reset_request_route_context,
    set_request_id,
    set_request_route_context,
)

if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import Iterable

    from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
|
|
|
logger = get_logger(__name__)

# Header used both to accept a client-supplied request id and to echo it back.
REQUEST_ID_HEADER: Final[str] = "X-Request-Id"

# Paths whose request logs are suppressed unless
# `settings.request_log_include_health` is enabled.
_HEALTH_CHECK_PATHS: Final[frozenset[str]] = frozenset(
    ("/health", "/healthz", "/readyz"),
)

# Type of an exception handler: called with the request and the raised
# exception, returning a response directly or via an awaitable.
ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]]
|
|
|
|
|
|
class RequestIdMiddleware:
    """ASGI middleware that ensures every HTTP request has a request id.

    For each ``http`` scope it:

    - reuses the first well-formed client-supplied request-id header value,
      otherwise generates ``uuid4().hex``;
    - stores the id in ``scope["state"]["request_id"]`` (readable as
      ``request.state.request_id``) and in the logging context vars for the
      duration of the request;
    - adds the request-id header to the response unless the app already set it;
    - logs request start (TRACE), completion (level chosen by status code),
      slow requests, and requests that never started a response.

    Scopes of any other type are passed to the wrapped app untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        header_name: str = REQUEST_ID_HEADER,
        max_request_id_length: int = 200,
    ) -> None:
        """Initialize the middleware.

        Args:
            app: The wrapped ASGI application.
            header_name: Header used to read a client-supplied id and to echo
                the id on the response.
            max_request_id_length: Longest client-supplied id accepted
                verbatim. Longer or non-printable-ASCII values are replaced by
                a generated id, so untrusted input is not copied unbounded into
                logs and response headers.
        """
        self._app = app
        self._header_name = header_name
        # Compared against ASGI header names, which are lowercase byte strings.
        self._header_name_bytes = header_name.lower().encode("latin-1")
        self._max_request_id_length = max_request_id_length
        self._slow_request_ms = settings.request_log_slow_ms
        self._include_health_logs = settings.request_log_include_health

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Inject the request id into request state, log context and response headers."""
        if scope["type"] != "http":
            await self._app(scope, receive, send)
            return

        method = str(scope.get("method") or "UNKNOWN").upper()
        path = str(scope.get("path") or "")
        client_ip = self._client_ip(scope)
        should_log = self._include_health_logs or path not in _HEALTH_CHECK_PATHS
        started_at = perf_counter()
        status_code: int | None = None

        request_id = self._get_or_create_request_id(scope)
        context_token = set_request_id(request_id)
        route_context_tokens = set_request_route_context(method, path)
        if should_log:
            logger.log(
                TRACE_LEVEL,
                "http.request.start",
                extra={"method": method, "path": path, "client_ip": client_ip},
            )

        async def send_with_request_id(message: Message) -> None:
            nonlocal status_code
            if message["type"] == "http.response.start":
                message["headers"] = self._with_request_id_header(
                    message.get("headers"),
                    request_id,
                )
                status = message.get("status")
                status_code = status if isinstance(status, int) else 500
                if should_log:
                    # Measured when the response headers are sent, i.e. before
                    # any (possibly streamed) body has been written.
                    self._log_completion(
                        method=method,
                        path=path,
                        client_ip=client_ip,
                        status_code=status_code,
                        duration_ms=int((perf_counter() - started_at) * 1000),
                    )
            await send(message)

        try:
            await self._app(scope, receive, send_with_request_id)
        finally:
            if should_log and status_code is None:
                # No `http.response.start` passed through this middleware,
                # e.g. because the app raised; Starlette then renders the 500
                # from ServerErrorMiddleware, which wraps this middleware.
                logger.warning(
                    "http.request.incomplete",
                    extra={
                        "method": method,
                        "path": path,
                        "duration_ms": int((perf_counter() - started_at) * 1000),
                        "client_ip": client_ip,
                    },
                )
            reset_request_route_context(route_context_tokens)
            reset_request_id(context_token)

    @staticmethod
    def _client_ip(scope: Scope) -> str | None:
        """Return the client host from ``scope["client"]``, or None if unavailable.

        The ASGI spec describes ``client`` as a two-item iterable
        ``[host, port]``; servers pass either a tuple or a list, so both are
        accepted.
        """
        client = scope.get("client")
        if isinstance(client, (tuple, list)) and client and isinstance(client[0], str):
            return client[0]
        return None

    def _with_request_id_header(
        self,
        headers: Iterable[tuple[bytes, bytes]] | None,
        request_id: str,
    ) -> list[tuple[bytes, bytes]]:
        """Return a copy of *headers* that carries the request-id header.

        ASGI allows any iterable of ``[name, value]`` pairs here, and Starlette
        passes its ``Response.raw_headers`` list directly. Building a new list
        (instead of appending in place) supports non-list iterables and keeps
        this request's id from sticking to a response object that is sent
        again for a later request.
        """
        new_headers = list(headers) if headers is not None else []
        if not any(name.lower() == self._header_name_bytes for name, _ in new_headers):
            new_headers.append((self._header_name_bytes, request_id.encode("latin-1")))
        return new_headers

    def _log_completion(
        self,
        *,
        method: str,
        path: str,
        client_ip: str | None,
        status_code: int,
        duration_ms: int,
    ) -> None:
        """Log the response outcome; 5xx at ERROR, 4xx at WARNING, else DEBUG.

        Also emits ``http.request.slow`` when a slow-request threshold is
        configured and *duration_ms* reaches it.
        """
        extra = {
            "method": method,
            "path": path,
            "status_code": status_code,
            "duration_ms": duration_ms,
            "client_ip": client_ip,
        }
        if status_code >= 500:
            logger.error("http.request.complete", extra=extra)
        elif status_code >= 400:
            logger.warning("http.request.complete", extra=extra)
        else:
            logger.debug("http.request.complete", extra=extra)
        if self._slow_request_ms and duration_ms >= self._slow_request_ms:
            logger.warning(
                "http.request.slow",
                extra={**extra, "slow_threshold_ms": self._slow_request_ms},
            )

    def _is_acceptable_request_id(self, candidate: str) -> bool:
        """Return True if a client-supplied id is printable ASCII and not too long."""
        return (
            len(candidate) <= self._max_request_id_length
            and candidate.isascii()
            and candidate.isprintable()
        )

    def _get_or_create_request_id(self, scope: Scope) -> str:
        """Return the request id for *scope*, recording it in ``scope["state"]``.

        The first request-id header whose stripped value is non-empty and
        passes ``_is_acceptable_request_id`` is reused; otherwise a new
        ``uuid4().hex`` is generated.
        """
        request_id: str | None = None
        for key, value in scope.get("headers", []):
            if key.lower() != self._header_name_bytes:
                continue
            candidate = value.decode("latin-1").strip()
            if candidate and self._is_acceptable_request_id(candidate):
                request_id = candidate
                break

        if request_id is None:
            request_id = uuid4().hex

        # `Request.state` is backed by `scope["state"]`; update any existing
        # dict rather than replacing it.
        state = scope.setdefault("state", {})
        state["request_id"] = request_id
        return request_id
|
|
|
|
|
|
def install_error_handling(app: FastAPI) -> None:
    """Register the request-id middleware and the JSON exception handlers on *app*.

    Call this after every other ``app.add_middleware(...)`` call. Starlette
    makes the most recently added middleware the outermost one, and the
    request-id middleware must wrap the others so it still runs when one of
    them (e.g. a CORS preflight) answers early.
    """
    app.add_middleware(RequestIdMiddleware)

    registrations: tuple[tuple[type[Exception], ExceptionHandler], ...] = (
        (RequestValidationError, _request_validation_exception_handler),
        (ResponseValidationError, _response_validation_exception_handler),
        (StarletteHTTPException, _http_exception_exception_handler),
        # Starlette runs the bare `Exception` handler from ServerErrorMiddleware,
        # which wraps all user middleware (including RequestIdMiddleware); that
        # is why `_unhandled_exception_handler` sets the request-id header itself.
        (Exception, _unhandled_exception_handler),
    )
    for exc_class, handler in registrations:
        app.add_exception_handler(exc_class, handler)
|
|
|
|
|
|
async def _request_validation_exception_handler(
    request: Request,
    exc: Exception,
) -> Response:
    """Narrow *exc* to `RequestValidationError` and delegate to its handler.

    Registered in `install_error_handling`.

    Raises:
        TypeError: If *exc* is not a `RequestValidationError`.
    """
    if isinstance(exc, RequestValidationError):
        return await _request_validation_handler(request, exc)
    msg = "Expected RequestValidationError"
    raise TypeError(msg)
|
|
|
|
|
|
async def _response_validation_exception_handler(
    request: Request,
    exc: Exception,
) -> Response:
    """Narrow *exc* to `ResponseValidationError` and delegate to its handler.

    Registered in `install_error_handling`.

    Raises:
        TypeError: If *exc* is not a `ResponseValidationError`.
    """
    if isinstance(exc, ResponseValidationError):
        return await _response_validation_handler(request, exc)
    msg = "Expected ResponseValidationError"
    raise TypeError(msg)
|
|
|
|
|
|
async def _http_exception_exception_handler(
    request: Request,
    exc: Exception,
) -> Response:
    """Narrow *exc* to Starlette's `HTTPException` and delegate to its handler.

    Registered in `install_error_handling`.

    Raises:
        TypeError: If *exc* is not a `StarletteHTTPException`.
    """
    if isinstance(exc, StarletteHTTPException):
        return await _http_exception_handler(request, exc)
    msg = "Expected StarletteHTTPException"
    raise TypeError(msg)
|
|
|
|
|
|
def _get_request_id(request: Request) -> str | None:
    """Return ``request.state.request_id`` if it is a non-empty string, else None."""
    candidate = getattr(request.state, "request_id", None)
    return candidate if isinstance(candidate, str) and candidate else None
|
|
|
|
|
|
def _error_payload(*, detail: object, request_id: str | None) -> dict[str, object]:
    """Build the JSON error body.

    Returns ``{"detail": detail}``, plus a ``"request_id"`` key only when
    *request_id* is truthy.
    """
    if not request_id:
        return {"detail": detail}
    return {"detail": detail, "request_id": request_id}
|
|
|
|
|
|
async def _request_validation_handler(
    request: Request,
    exc: RequestValidationError,
) -> JSONResponse:
    """Return a 422 response listing the request's validation errors.

    ``exc.errors()`` can contain values the stdlib JSON encoder rejects. Examples
    are the exception instance a custom validator raised, stored under
    ``ctx``, and raw ``bytes`` input. Those values would make ``JSONResponse``
    raise while rendering and turn a client error into a 500. The errors are
    therefore passed through ``jsonable_encoder`` first, as FastAPI's built-in
    handler does.
    """
    # `RequestValidationError` is expected user input; don't log at ERROR.
    request_id = _get_request_id(request)
    return JSONResponse(
        status_code=422,
        content=_error_payload(
            detail=jsonable_encoder(exc.errors()),
            request_id=request_id,
        ),
    )
|
|
|
|
|
|
async def _response_validation_handler(
    request: Request,
    exc: ResponseValidationError,
) -> JSONResponse:
    """Log a response-model validation failure and return a generic 500.

    The validation errors are written to the log via ``logger.exception``
    (ERROR level) and are never included in the client response.
    """
    request_id = _get_request_id(request)
    log_context = {
        "request_id": request_id,
        "method": request.method,
        "path": request.url.path,
        "errors": exc.errors(),
    }
    logger.exception("response_validation_error", extra=log_context)
    body = _error_payload(detail="Internal Server Error", request_id=request_id)
    return JSONResponse(status_code=500, content=body)
|
|
|
|
|
|
async def _http_exception_handler(
    request: Request,
    exc: StarletteHTTPException,
) -> Response:
    """Render an ``HTTPException`` as a JSON error body, preserving its headers.

    Statuses that must not carry a body (1xx, 204, 205, 304) get an empty
    response instead. HTTP forbids a body for them, and ASGI servers may reject
    or mangle a response that includes one. This mirrors FastAPI's and
    Starlette's default HTTP exception handlers.
    """
    headers = exc.headers
    if exc.status_code < 200 or exc.status_code in {204, 205, 304}:
        return Response(status_code=exc.status_code, headers=headers)
    request_id = _get_request_id(request)
    return JSONResponse(
        status_code=exc.status_code,
        content=_error_payload(detail=exc.detail, request_id=request_id),
        headers=headers,
    )
|
|
|
|
|
|
async def _unhandled_exception_handler(
    request: Request,
    _exc: Exception,
) -> JSONResponse:
    """Log an otherwise-unhandled exception and return a generic 500 response.

    The request-id header is set explicitly here because Starlette invokes the
    ``Exception`` handler from ServerErrorMiddleware, outside every user
    middleware, so ``RequestIdMiddleware`` never sees this response. For the
    same reason the id is passed in ``extra`` rather than relying on the
    logging context, which may already have been reset.
    """
    request_id = _get_request_id(request)
    log_context = {
        "request_id": request_id,
        "method": request.method,
        "path": request.url.path,
    }
    logger.exception("unhandled_exception", extra=log_context)
    body = _error_payload(detail="Internal Server Error", request_id=request_id)
    response_headers = {REQUEST_ID_HEADER: request_id} if request_id else None
    return JSONResponse(status_code=500, content=body, headers=response_headers)
|