refactor: update module docstrings for clarity and consistency

This commit is contained in:
Abhimanyu Saharan
2026-02-09 15:49:50 +05:30
parent 78bb08d4a3
commit 7ca1899d9f
99 changed files with 2345 additions and 855 deletions

View File

@@ -0,0 +1 @@
"""Core utilities and configuration for the backend service."""

View File

@@ -1,33 +1,44 @@
"""Agent authentication helpers for token-backed API access."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import timedelta
from typing import Literal
from typing import TYPE_CHECKING, Literal
from fastapi import Depends, Header, HTTPException, Request, status
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import verify_agent_token
from app.core.time import utcnow
from app.db.session import get_session
from app.models.agents import Agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
logger = logging.getLogger(__name__)
_LAST_SEEN_TOUCH_INTERVAL = timedelta(seconds=30)
_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
SESSION_DEP = Depends(get_session)
@dataclass
class AgentAuthContext:
    """Authenticated actor payload for agent-originated requests."""

    # Discriminator for the actor kind; always the literal string "agent".
    actor_type: Literal["agent"]
    # The Agent record resolved from the presented token.
    agent: Agent
async def _find_agent_for_token(session: AsyncSession, token: str) -> Agent | None:
agents = list(await session.exec(select(Agent).where(col(Agent.agent_token_hash).is_not(None))))
agents = list(
await session.exec(
select(Agent).where(col(Agent.agent_token_hash).is_not(None)),
),
)
for agent in agents:
if agent.agent_token_hash and verify_agent_token(token, agent.agent_token_hash):
return agent
@@ -65,9 +76,11 @@ async def _touch_agent_presence(
calls (task comments, memory updates, etc). Touch presence so the UI reflects
real activity even if the heartbeat loop isn't running.
"""
now = utcnow()
if agent.last_seen_at is not None and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL:
if (
agent.last_seen_at is not None
and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL
):
return
agent.last_seen_at = now
@@ -86,9 +99,14 @@ async def get_agent_auth_context(
request: Request,
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
authorization: str | None = Header(default=None, alias="Authorization"),
session: AsyncSession = Depends(get_session),
session: AsyncSession = SESSION_DEP,
) -> AgentAuthContext:
resolved = _resolve_agent_token(agent_token, authorization, accept_authorization=True)
"""Require and validate agent auth token from request headers."""
resolved = _resolve_agent_token(
agent_token,
authorization,
accept_authorization=True,
)
if not resolved:
logger.warning(
"agent auth missing token path=%s x_agent=%s authorization=%s",
@@ -113,8 +131,9 @@ async def get_agent_auth_context_optional(
request: Request,
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
authorization: str | None = Header(default=None, alias="Authorization"),
session: AsyncSession = Depends(get_session),
session: AsyncSession = SESSION_DEP,
) -> AgentAuthContext | None:
"""Optionally resolve agent auth context from `X-Agent-Token` only."""
resolved = _resolve_agent_token(
agent_token,
authorization,

View File

@@ -1,3 +1,5 @@
"""Token generation and verification helpers for agent authentication."""
from __future__ import annotations
import base64
@@ -10,6 +12,7 @@ SALT_BYTES = 16
def generate_agent_token() -> str:
    """Create a fresh, URL-safe random token for an agent.

    Uses the ``secrets`` module (CSPRNG-backed) with 32 bytes of entropy.
    """
    raw_token = secrets.token_urlsafe(32)
    return raw_token
@@ -23,12 +26,14 @@ def _b64decode(value: str) -> bytes:
def hash_agent_token(token: str) -> str:
    """Hash an agent token using PBKDF2-HMAC-SHA256 with a random salt.

    Returns a self-describing string of the form
    ``pbkdf2_sha256$<iterations>$<salt_b64>$<digest_b64>``.
    """
    # A fresh random salt per call, so identical tokens never share a hash.
    random_salt = secrets.token_bytes(SALT_BYTES)
    derived = hashlib.pbkdf2_hmac(
        "sha256",
        token.encode("utf-8"),
        random_salt,
        ITERATIONS,
    )
    encoded_salt = _b64encode(random_salt)
    encoded_digest = _b64encode(derived)
    return f"pbkdf2_sha256${ITERATIONS}${encoded_salt}${encoded_digest}"
def verify_agent_token(token: str, stored_hash: str) -> bool:
"""Verify a plaintext token against a stored PBKDF2 hash representation."""
try:
algorithm, iterations, salt_b64, digest_b64 = stored_hash.split("$")
except ValueError:
@@ -41,5 +46,10 @@ def verify_agent_token(token: str, stored_hash: str) -> bool:
return False
salt = _b64decode(salt_b64)
expected_digest = _b64decode(digest_b64)
candidate = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, iterations_int)
candidate = hashlib.pbkdf2_hmac(
"sha256",
token.encode("utf-8"),
salt,
iterations_int,
)
return hmac.compare_digest(candidate, expected_digest)

View File

@@ -1,32 +1,42 @@
"""User authentication helpers backed by Clerk JWT verification."""
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal
from typing import TYPE_CHECKING, Literal
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer
from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials
from pydantic import BaseModel, ValidationError
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.db import crud
from app.db.session import get_session
from app.models.users import User
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer(auto_error=False)
SECURITY_DEP = Depends(security)
SESSION_DEP = Depends(get_session)
CLERK_JWKS_URL_REQUIRED_ERROR = "CLERK_JWKS_URL is not set."
class ClerkTokenPayload(BaseModel):
    """JWT claims payload shape required from Clerk tokens."""

    # JWT "sub" (subject) claim — presumably the Clerk user id; verify against callers.
    sub: str
@lru_cache
def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer:
def _build_clerk_http_bearer(*, auto_error: bool) -> ClerkHTTPBearer:
"""Create and cache the Clerk HTTP bearer guard."""
if not settings.clerk_jwks_url:
raise RuntimeError("CLERK_JWKS_URL is not set.")
raise RuntimeError(CLERK_JWKS_URL_REQUIRED_ERROR)
clerk_config = ClerkConfig(
jwks_url=settings.clerk_jwks_url,
verify_iat=settings.clerk_verify_iat,
@@ -37,12 +47,15 @@ def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer:
@dataclass
class AuthContext:
    """Authenticated user context resolved from inbound auth headers."""

    # Discriminator for the actor kind; always the literal string "user".
    actor_type: Literal["user"]
    # The resolved User, if any; defaults to None.
    user: User | None = None
def _resolve_clerk_auth(
request: Request, fallback: ClerkCredentials | None
request: Request,
fallback: ClerkCredentials | None,
) -> ClerkCredentials | None:
auth_data = getattr(request.state, "clerk_auth", None)
if isinstance(auth_data, ClerkCredentials):
@@ -59,9 +72,10 @@ def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
async def get_auth_context(
request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(security),
session: AsyncSession = Depends(get_session),
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = SESSION_DEP,
) -> AuthContext:
"""Resolve required authenticated user context from Clerk JWT headers."""
if credentials is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
@@ -109,9 +123,10 @@ async def get_auth_context(
async def get_auth_context_optional(
request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(security),
session: AsyncSession = Depends(get_session),
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = SESSION_DEP,
) -> AuthContext | None:
"""Resolve user context if available, otherwise return `None`."""
if request.headers.get("X-Agent-Token"):
return None
if credentials is None:

View File

@@ -1,3 +1,5 @@
"""Application settings and environment configuration loading."""
from __future__ import annotations
from pathlib import Path
@@ -11,6 +13,8 @@ DEFAULT_ENV_FILE = BACKEND_ROOT / ".env"
class Settings(BaseSettings):
"""Typed runtime configuration sourced from environment variables."""
model_config = SettingsConfigDict(
# Load `backend/.env` regardless of current working directory.
# (Important when running uvicorn from repo root or via a process manager.)
@@ -32,8 +36,8 @@ class Settings(BaseSettings):
base_url: str = ""
# Optional: local directory where the backend is allowed to write "preserved" agent
# workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local writes are disabled
# and provisioning relies on the gateway API.
# workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local
# writes are disabled and provisioning relies on the gateway API.
#
# Security note: do NOT point this at arbitrary system paths in production.
local_agent_workspace_root: str = ""
@@ -48,8 +52,8 @@ class Settings(BaseSettings):
@model_validator(mode="after")
def _defaults(self) -> Self:
# In dev, default to applying Alembic migrations at startup to avoid schema drift
# (e.g. missing newly-added columns).
# In dev, default to applying Alembic migrations at startup to avoid
# schema drift (e.g. missing newly-added columns).
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":
self.db_auto_migrate = True
return self

View File

@@ -1,8 +1,13 @@
"""Utilities for parsing human-readable duration schedule strings."""
from __future__ import annotations
import re
_DURATION_RE = re.compile(r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$", flags=re.IGNORECASE)
_DURATION_RE = re.compile(
r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$",
flags=re.IGNORECASE,
)
_MULTIPLIERS: dict[str, int] = {
"s": 1,
@@ -11,26 +16,36 @@ _MULTIPLIERS: dict[str, int] = {
"d": 60 * 60 * 24,
"w": 60 * 60 * 24 * 7,
}
_MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10
_ERR_SCHEDULE_REQUIRED = "schedule is required"
_ERR_SCHEDULE_INVALID = (
'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
)
_ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0."
_ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)."
def normalize_every(value: str) -> str:
"""Normalize schedule string to lower-case compact unit form."""
normalized = value.strip().lower().replace(" ", "")
if not normalized:
raise ValueError("schedule is required")
raise ValueError(_ERR_SCHEDULE_REQUIRED)
return normalized
def parse_every_to_seconds(value: str) -> int:
"""Parse compact schedule syntax into a number of seconds."""
normalized = normalize_every(value)
match = _DURATION_RE.match(normalized)
if not match:
raise ValueError('Invalid schedule. Expected format like "10m", "1h", "2d", "1w".')
raise ValueError(_ERR_SCHEDULE_INVALID)
num = int(match.group("num"))
unit = match.group("unit").lower()
seconds = num * _MULTIPLIERS[unit]
if seconds <= 0:
raise ValueError("Schedule must be greater than 0.")
raise ValueError(_ERR_SCHEDULE_NONPOSITIVE)
# Prevent accidental absurd schedules (e.g. 999999999d).
if seconds > 60 * 60 * 24 * 365 * 10:
raise ValueError("Schedule is too large (max 10 years).")
if seconds > _MAX_SCHEDULE_SECONDS:
raise ValueError(_ERR_SCHEDULE_TOO_LARGE)
return seconds

View File

@@ -1,8 +1,10 @@
"""Global exception handlers and request-id middleware for FastAPI."""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any, Final, cast
from typing import TYPE_CHECKING, Any, Final, cast
from uuid import uuid4
from fastapi import FastAPI, Request
@@ -10,7 +12,9 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
if TYPE_CHECKING:
from starlette.types import ASGIApp, Message, Receive, Scope, Send
logger = logging.getLogger(__name__)
@@ -20,12 +24,16 @@ ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]
class RequestIdMiddleware:
"""ASGI middleware that ensures every request has a request-id."""
def __init__(self, app: ASGIApp, *, header_name: str = REQUEST_ID_HEADER) -> None:
"""Initialize middleware with app instance and header name."""
self._app = app
self._header_name = header_name
self._header_name_bytes = header_name.lower().encode("latin-1")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Inject request-id into request state and response headers."""
if scope["type"] != "http":
await self._app(scope, receive, send)
return
@@ -36,8 +44,11 @@ class RequestIdMiddleware:
if message["type"] == "http.response.start":
# Starlette uses `list[tuple[bytes, bytes]]` here.
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
if not any(key.lower() == self._header_name_bytes for key, _ in headers):
headers.append((self._header_name_bytes, request_id.encode("latin-1")))
if not any(
key.lower() == self._header_name_bytes for key, _ in headers
):
request_id_bytes = request_id.encode("latin-1")
headers.append((self._header_name_bytes, request_id_bytes))
await send(message)
await self._app(scope, receive, send_with_request_id)
@@ -62,8 +73,10 @@ class RequestIdMiddleware:
def install_error_handling(app: FastAPI) -> None:
"""Install middleware and exception handlers on the FastAPI app."""
# Important: add request-id middleware last so it's the outermost middleware.
# This ensures it still runs even if another middleware (e.g. CORS preflight) returns early.
# This ensures it still runs even if another middleware
# (e.g. CORS preflight) returns early.
app.add_middleware(RequestIdMiddleware)
app.add_exception_handler(
@@ -88,7 +101,7 @@ def _get_request_id(request: Request) -> str | None:
return None
def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
def _error_payload(*, detail: object, request_id: str | None) -> dict[str, object]:
payload: dict[str, Any] = {"detail": detail}
if request_id:
payload["request_id"] = request_id
@@ -96,7 +109,8 @@ def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
async def _request_validation_handler(
request: Request, exc: RequestValidationError
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
# `RequestValidationError` is expected user input; don't log at ERROR.
request_id = _get_request_id(request)
@@ -107,7 +121,8 @@ async def _request_validation_handler(
async def _response_validation_handler(
request: Request, exc: ResponseValidationError
request: Request,
exc: ResponseValidationError,
) -> JSONResponse:
request_id = _get_request_id(request)
logger.exception(
@@ -125,7 +140,10 @@ async def _response_validation_handler(
)
async def _http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
async def _http_exception_handler(
request: Request,
exc: StarletteHTTPException,
) -> JSONResponse:
request_id = _get_request_id(request)
return JSONResponse(
status_code=exc.status_code,
@@ -134,11 +152,18 @@ async def _http_exception_handler(request: Request, exc: StarletteHTTPException)
)
async def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
async def _unhandled_exception_handler(
request: Request,
_exc: Exception,
) -> JSONResponse:
request_id = _get_request_id(request)
logger.exception(
"unhandled_exception",
extra={"request_id": request_id, "method": request.method, "path": request.url.path},
extra={
"request_id": request_id,
"method": request.method,
"path": request.url.path,
},
)
return JSONResponse(
status_code=500,

View File

@@ -1,3 +1,5 @@
"""Application logging configuration and formatter utilities."""
from __future__ import annotations
import json
@@ -15,7 +17,8 @@ TRACE_LEVEL = 5
logging.addLevelName(TRACE_LEVEL, "TRACE")
def _trace(self: logging.Logger, message: str, *args: Any, **kwargs: Any) -> None:
def _trace(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None:
"""Log a TRACE-level message when the logger is TRACE-enabled."""
if self.isEnabledFor(TRACE_LEVEL):
self._log(TRACE_LEVEL, message, args, **kwargs)
@@ -52,21 +55,31 @@ _STANDARD_LOG_RECORD_ATTRS = {
class AppLogFilter(logging.Filter):
"""Inject app metadata into each log record."""
def __init__(self, app_name: str, version: str) -> None:
"""Initialize the filter with fixed app and version values."""
super().__init__()
self._app_name = app_name
self._version = version
def filter(self, record: logging.LogRecord) -> bool:
"""Attach app metadata fields to each emitted record."""
record.app = self._app_name
record.version = self._version
return True
class JsonFormatter(logging.Formatter):
"""Formatter that serializes log records as compact JSON."""
def format(self, record: logging.LogRecord) -> str:
"""Render a single log record into a JSON string."""
payload: dict[str, Any] = {
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
"timestamp": datetime.fromtimestamp(
record.created,
tz=timezone.utc,
).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
@@ -88,7 +101,10 @@ class JsonFormatter(logging.Formatter):
class KeyValueFormatter(logging.Formatter):
"""Formatter that appends extra fields as `key=value` pairs."""
def format(self, record: logging.LogRecord) -> str:
"""Render a log line with appended non-standard record fields."""
base = super().format(record)
extras = {
key: value
@@ -102,6 +118,8 @@ class KeyValueFormatter(logging.Formatter):
class AppLogger:
"""Centralized logging setup utility for the backend process."""
_configured = False
@classmethod
@@ -111,10 +129,12 @@ class AppLogger:
return level_name, TRACE_LEVEL
if level_name.isdigit():
return level_name, int(level_name)
return level_name, logging._nameToLevel.get(level_name, logging.INFO)
levels = logging.getLevelNamesMapping()
return level_name, levels.get(level_name, logging.INFO)
@classmethod
def configure(cls, *, force: bool = False) -> None:
"""Configure root logging handlers, formatters, and library levels."""
if cls._configured and not force:
return
@@ -127,7 +147,8 @@ class AppLogger:
formatter: logging.Formatter = JsonFormatter()
else:
formatter = KeyValueFormatter(
"%(asctime)s %(levelname)s %(name)s %(message)s app=%(app)s version=%(version)s"
"%(asctime)s %(levelname)s %(name)s %(message)s "
"app=%(app)s version=%(version)s",
)
if settings.log_use_utc:
formatter.converter = time.gmtime
@@ -160,10 +181,12 @@ class AppLogger:
@classmethod
def get_logger(cls, name: str | None = None) -> logging.Logger:
"""Return a logger, ensuring logging has been configured."""
if not cls._configured:
cls.configure()
return logging.getLogger(name)
def configure_logging() -> None:
"""Configure global application logging once during startup."""
AppLogger.configure()

View File

@@ -1,3 +1,5 @@
"""Time-related helpers shared across backend modules."""
from __future__ import annotations
from datetime import UTC, datetime
@@ -5,6 +7,5 @@ from datetime import UTC, datetime
def utcnow() -> datetime:
"""Return a naive UTC datetime without using deprecated datetime.utcnow()."""
# Keep naive UTC values for compatibility with existing DB schema/queries.
return datetime.now(UTC).replace(tzinfo=None)

View File

@@ -1,2 +1,4 @@
"""Application name and version constants."""
APP_NAME = "mission-control"
APP_VERSION = "0.1.0"