feat: implement error handling middleware with request ID tracking for improved error responses
This commit is contained in:
143
backend/app/core/error_handling.py
Normal file
143
backend/app/core/error_handling.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Final, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REQUEST_ID_HEADER: Final[str] = "X-Request-Id"
|
||||
|
||||
ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]]
|
||||
|
||||
|
||||
class RequestIdMiddleware:
|
||||
def __init__(self, app: ASGIApp, *, header_name: str = REQUEST_ID_HEADER) -> None:
|
||||
self._app = app
|
||||
self._header_name = header_name
|
||||
self._header_name_bytes = header_name.lower().encode("latin-1")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self._app(scope, receive, send)
|
||||
return
|
||||
|
||||
request_id = self._get_or_create_request_id(scope)
|
||||
|
||||
async def send_with_request_id(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
# Starlette uses `list[tuple[bytes, bytes]]` here.
|
||||
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
|
||||
if not any(key.lower() == self._header_name_bytes for key, _ in headers):
|
||||
headers.append((self._header_name_bytes, request_id.encode("latin-1")))
|
||||
await send(message)
|
||||
|
||||
await self._app(scope, receive, send_with_request_id)
|
||||
|
||||
def _get_or_create_request_id(self, scope: Scope) -> str:
|
||||
# Accept a client-provided request id if present.
|
||||
request_id: str | None = None
|
||||
for key, value in scope.get("headers", []):
|
||||
if key.lower() == self._header_name_bytes:
|
||||
candidate = value.decode("latin-1").strip()
|
||||
if candidate:
|
||||
request_id = candidate
|
||||
break
|
||||
|
||||
if request_id is None:
|
||||
request_id = uuid4().hex
|
||||
|
||||
# `Request.state` is backed by `scope["state"]`.
|
||||
state = scope.setdefault("state", {})
|
||||
state["request_id"] = request_id
|
||||
return request_id
|
||||
|
||||
|
||||
def install_error_handling(app: FastAPI) -> None:
|
||||
# Important: add request-id middleware last so it's the outermost middleware.
|
||||
# This ensures it still runs even if another middleware (e.g. CORS preflight) returns early.
|
||||
app.add_middleware(RequestIdMiddleware)
|
||||
|
||||
app.add_exception_handler(
|
||||
RequestValidationError,
|
||||
cast(ExceptionHandler, _request_validation_handler),
|
||||
)
|
||||
app.add_exception_handler(
|
||||
ResponseValidationError,
|
||||
cast(ExceptionHandler, _response_validation_handler),
|
||||
)
|
||||
app.add_exception_handler(
|
||||
StarletteHTTPException,
|
||||
cast(ExceptionHandler, _http_exception_handler),
|
||||
)
|
||||
app.add_exception_handler(Exception, _unhandled_exception_handler)
|
||||
|
||||
|
||||
def _get_request_id(request: Request) -> str | None:
|
||||
request_id = getattr(request.state, "request_id", None)
|
||||
if isinstance(request_id, str) and request_id:
|
||||
return request_id
|
||||
return None
|
||||
|
||||
|
||||
def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"detail": detail}
|
||||
if request_id:
|
||||
payload["request_id"] = request_id
|
||||
return payload
|
||||
|
||||
|
||||
async def _request_validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
# `RequestValidationError` is expected user input; don't log at ERROR.
|
||||
request_id = _get_request_id(request)
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=_error_payload(detail=exc.errors(), request_id=request_id),
|
||||
)
|
||||
|
||||
|
||||
async def _response_validation_handler(request: Request, exc: ResponseValidationError) -> JSONResponse:
|
||||
request_id = _get_request_id(request)
|
||||
logger.exception(
|
||||
"response_validation_error",
|
||||
extra={
|
||||
"request_id": request_id,
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"errors": exc.errors(),
|
||||
},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=_error_payload(detail="Internal Server Error", request_id=request_id),
|
||||
)
|
||||
|
||||
|
||||
async def _http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
request_id = _get_request_id(request)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=_error_payload(detail=exc.detail, request_id=request_id),
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
async def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
request_id = _get_request_id(request)
|
||||
logger.exception(
|
||||
"unhandled_exception",
|
||||
extra={"request_id": request_id, "method": request.method, "path": request.url.path},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=_error_payload(detail="Internal Server Error", request_id=request_id),
|
||||
headers={REQUEST_ID_HEADER: request_id} if request_id else None,
|
||||
)
|
||||
Reference in New Issue
Block a user