230 lines
7.7 KiB
Python
230 lines
7.7 KiB
Python
# ruff: noqa
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
import pytest
|
|
|
|
from app.core import error_handling as error_handling_module
|
|
from app.core.error_handling import REQUEST_ID_HEADER, RequestIdMiddleware
|
|
from app.core.logging import TRACE_LEVEL, AppLogFilter, get_logger
|
|
from app.core.version import APP_NAME, APP_VERSION
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_id_middleware_passes_through_non_http_scope() -> None:
    """A non-HTTP (websocket) scope must still reach the wrapped ASGI app."""
    invocations: list[bool] = []

    async def downstream(scope, receive, send):  # type: ignore[no-untyped-def]
        invocations.append(True)

    wrapped = RequestIdMiddleware(downstream)
    ws_scope = {"type": "websocket", "headers": []}

    # The downstream app never touches receive/send, so plain (non-awaitable)
    # lambdas are enough here.
    await wrapped(ws_scope, lambda: None, lambda message: None)  # type: ignore[arg-type]

    assert invocations
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_id_middleware_ignores_blank_client_header_and_generates_one() -> None:
    """A whitespace-only client request id is replaced by a generated one.

    The id stored on ``scope["state"]`` must be a non-empty string, and the
    response must carry exactly that id in the request-id header.
    """
    header_key = REQUEST_ID_HEADER.lower().encode("latin-1")
    seen: dict[str, str | None] = {"request_id": None}
    response_headers: list[tuple[bytes, bytes]] = []

    async def downstream(scope, receive, send):  # type: ignore[no-untyped-def]
        seen["request_id"] = scope.get("state", {}).get("request_id")
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})

    async def capture_send(message):  # type: ignore[no-untyped-def]
        if message["type"] != "http.response.start":
            return
        response_headers.extend(message.get("headers") or [])

    http_scope = {
        "type": "http",
        "headers": [(header_key, b" ")],
    }
    await RequestIdMiddleware(downstream)(http_scope, lambda: None, capture_send)

    generated = seen["request_id"]
    assert isinstance(generated, str) and generated

    # Header should reflect the generated id, not the blank one.
    echoed = [value for key, value in response_headers if key.lower() == header_key]
    assert echoed == [generated.encode("latin-1")]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_id_middleware_does_not_duplicate_existing_header() -> None:
    """If the app already sets the request-id header, no second copy is added."""
    header_key = REQUEST_ID_HEADER.lower().encode("latin-1")
    # One entry per forwarded http.response.start message (its header list).
    starts: list[list[tuple[bytes, bytes]]] = []

    async def downstream(scope, receive, send):  # type: ignore[no-untyped-def]
        # Simulate an app that already sets the request id header.
        start_message = {
            "type": "http.response.start",
            "status": 200,
            "headers": [(header_key, b"already")],
        }
        await send(start_message)
        await send({"type": "http.response.body", "body": b"ok"})

    async def capture_send(message):  # type: ignore[no-untyped-def]
        if message["type"] == "http.response.start":
            starts.append(list(message.get("headers") or []))

    http_scope = {"type": "http", "headers": []}
    await RequestIdMiddleware(downstream)(http_scope, lambda: None, capture_send)

    assert starts
    final_headers = starts[-1]

    # Ensure the middleware did not append a second copy.
    matches = [value for key, value in final_headers if key.lower() == header_key]
    assert matches == [b"already"]
|
|
|
|
|
|
class _CaptureHandler(logging.Handler):
    """Logging handler that keeps every emitted record in memory for assertions."""

    records: list[logging.LogRecord]

    def __init__(self) -> None:
        super().__init__()
        self.records = []

    def emit(self, record: logging.LogRecord) -> None:
        """Store *record* instead of formatting or writing it anywhere."""
        self.records.append(record)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_id_middleware_logs_trace_start_and_debug_completion() -> None:
    """Check the start/complete log records emitted for a successful request.

    The test expects two records:

    - ``http.request.start`` at TRACE, carrying ``method`` and ``path``.
    - ``http.request.complete`` at DEBUG, carrying ``status_code`` and an
      integer ``duration_ms``.

    The middleware's module logger is shared global state. The capture
    handler is removed and the logger's previous level is restored afterwards,
    so TRACE verbosity does not leak into other tests.
    """
    capture = _CaptureHandler()
    capture.setLevel(TRACE_LEVEL)

    logger = error_handling_module.logger
    previous_level = logger.level
    logger.setLevel(TRACE_LEVEL)
    logger.addHandler(capture)

    async def app(scope, receive, send):  # type: ignore[no-untyped-def]
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})

    middleware = RequestIdMiddleware(app)
    request_scope = {
        "type": "http",
        "method": "GET",
        "path": "/api/v1/auth/bootstrap",
        "client": ("127.0.0.1", 5454),
        "headers": [],
    }
    sent_messages: list[dict[str, object]] = []

    async def send(message):  # type: ignore[no-untyped-def]
        sent_messages.append(message)

    try:
        await middleware(request_scope, lambda: None, send)
    finally:
        logger.removeHandler(capture)
        logger.setLevel(previous_level)
        capture.close()

    # Collect matches into lists instead of calling bare next(). A
    # StopIteration escaping a coroutine is re-raised as an opaque
    # RuntimeError that would hide which record was missing.
    captured = [record.getMessage() for record in capture.records]
    starts = [r for r in capture.records if r.getMessage() == "http.request.start"]
    completes = [r for r in capture.records if r.getMessage() == "http.request.complete"]
    assert starts, f"no http.request.start record captured; got {captured}"
    assert completes, f"no http.request.complete record captured; got {captured}"
    start, complete = starts[0], completes[0]

    assert start.levelname == "TRACE"
    assert getattr(start, "method", None) == "GET"
    assert getattr(start, "path", None) == "/api/v1/auth/bootstrap"

    assert complete.levelname == "DEBUG"
    assert getattr(complete, "status_code", None) == 200
    assert isinstance(getattr(complete, "duration_ms", None), int)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_id_middleware_logs_error_for_5xx_completion() -> None:
    """A 5xx response must produce an ERROR-level ``http.request.complete`` record.

    The record must carry the response ``status_code``. The middleware's
    module logger is shared global state. The capture handler is removed and
    the logger's previous level is restored afterwards, so TRACE verbosity
    does not leak into other tests.
    """
    capture = _CaptureHandler()
    capture.setLevel(TRACE_LEVEL)

    logger = error_handling_module.logger
    previous_level = logger.level
    logger.setLevel(TRACE_LEVEL)
    logger.addHandler(capture)

    async def app(scope, receive, send):  # type: ignore[no-untyped-def]
        await send({"type": "http.response.start", "status": 503, "headers": []})
        await send({"type": "http.response.body", "body": b"unavailable"})

    middleware = RequestIdMiddleware(app)
    request_scope = {
        "type": "http",
        "method": "POST",
        "path": "/api/v1/tasks",
        "client": ("127.0.0.1", 5454),
        "headers": [],
    }
    sent_messages: list[dict[str, object]] = []

    async def send(message):  # type: ignore[no-untyped-def]
        sent_messages.append(message)

    try:
        await middleware(request_scope, lambda: None, send)
    finally:
        logger.removeHandler(capture)
        logger.setLevel(previous_level)
        capture.close()

    # A list plus an explicit assert gives a readable failure. A bare next()
    # would surface as "coroutine raised StopIteration" inside this async test.
    completes = [r for r in capture.records if r.getMessage() == "http.request.complete"]
    assert completes, (
        "no http.request.complete record captured; got "
        f"{[r.getMessage() for r in capture.records]}"
    )
    complete = completes[0]

    assert complete.levelname == "ERROR"
    assert getattr(complete, "status_code", None) == 503
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_id_middleware_enriches_in_request_logs_with_route_context() -> None:
    """Logs emitted while handling a request carry request id, method and path.

    The capture handler runs records through ``AppLogFilter``. A record logged
    from inside the wrapped app must expose:

    - a non-empty string ``request_id``;
    - the request ``method``;
    - the request ``path``.

    The test logger's handler and previous level are restored afterwards.
    """
    capture = _CaptureHandler()
    capture.setLevel(TRACE_LEVEL)
    capture.addFilter(AppLogFilter(APP_NAME, APP_VERSION))

    app_logger = get_logger("tests.request_context.enrichment")
    previous_level = app_logger.level
    app_logger.setLevel(TRACE_LEVEL)
    app_logger.addHandler(capture)

    async def app(scope, receive, send):  # type: ignore[no-untyped-def]
        app_logger.info("inside.request.handler")
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})

    middleware = RequestIdMiddleware(app)
    request_scope = {
        "type": "http",
        "method": "PUT",
        "path": "/api/v1/boards/abc",
        "client": ("127.0.0.1", 5454),
        "headers": [],
    }

    async def send(_message):  # type: ignore[no-untyped-def]
        return None

    try:
        await middleware(request_scope, lambda: None, send)
    finally:
        app_logger.removeHandler(capture)
        app_logger.setLevel(previous_level)
        capture.close()

    # A list plus an explicit assert gives a readable failure. A bare next()
    # would surface as "coroutine raised StopIteration" inside this async test.
    matches = [item for item in capture.records if item.getMessage() == "inside.request.handler"]
    assert matches, (
        "no inside.request.handler record captured; got "
        f"{[item.getMessage() for item in capture.records]}"
    )
    record = matches[0]

    request_id = getattr(record, "request_id", None)
    assert isinstance(request_id, str) and request_id
    assert getattr(record, "method", None) == "PUT"
    assert getattr(record, "path", None) == "/api/v1/boards/abc"
|