refactor(auth): streamline authentication flow and enhance error logging

This commit is contained in:
Abhimanyu Saharan
2026-02-10 18:40:42 +05:30
parent 5c25c4bb91
commit 01888062a6
2 changed files with 155 additions and 189 deletions

View File

@@ -4,17 +4,17 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from time import monotonic
from typing import TYPE_CHECKING, Literal
import httpx
import jwt
from clerk_backend_api import Clerk
from clerk_backend_api.models.clerkerrors import ClerkErrors
from clerk_backend_api.models.sdkerror import SDKError
from clerk_backend_api.security.types import AuthenticateRequestOptions, AuthStatus, RequestState
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, ValidationError
from starlette.concurrency import run_in_threadpool
from app.core.config import settings
from app.db import crud
@@ -29,9 +29,6 @@ logger = logging.getLogger(__name__)
security = HTTPBearer(auto_error=False)
SECURITY_DEP = Depends(security)
SESSION_DEP = Depends(get_session)
_JWKS_CACHE_TTL_SECONDS = 300.0
_jwks_cache_payload: dict[str, object] | None = None
_jwks_cache_at_monotonic = 0.0
class ClerkTokenPayload(BaseModel):
@@ -107,42 +104,6 @@ def _extract_claim_name(claims: dict[str, object]) -> str | None:
return " ".join(parts)
def _claim_debug_snapshot(claims: dict[str, object]) -> dict[str, object]:
email_addresses = claims.get("email_addresses")
email_samples: list[dict[str, str | None]] = []
if isinstance(email_addresses, list):
for item in email_addresses[:5]:
if isinstance(item, dict):
email_samples.append(
{
"id": _non_empty_str(item.get("id")),
"email": _normalize_email(
item.get("email_address") or item.get("email"),
),
},
)
elif isinstance(item, str):
email_samples.append({"id": None, "email": _normalize_email(item)})
return {
"keys": sorted(claims.keys()),
"iss": _non_empty_str(claims.get("iss")),
"sub": _non_empty_str(claims.get("sub")),
"email": _normalize_email(claims.get("email")),
"email_address": _normalize_email(claims.get("email_address")),
"primary_email_address": _normalize_email(claims.get("primary_email_address")),
"primary_email_address_id": _non_empty_str(claims.get("primary_email_address_id")),
"email_addresses_count": len(email_addresses) if isinstance(email_addresses, list) else 0,
"email_addresses_sample": email_samples,
"name": _non_empty_str(claims.get("name")),
"full_name": _non_empty_str(claims.get("full_name")),
"given_name": _non_empty_str(claims.get("given_name"))
or _non_empty_str(claims.get("first_name")),
"family_name": _non_empty_str(claims.get("family_name"))
or _non_empty_str(claims.get("last_name")),
}
def _extract_clerk_profile(profile: ClerkUser | None) -> tuple[str | None, str | None]:
if profile is None:
return None, None
@@ -192,98 +153,32 @@ def _normalize_clerk_server_url(raw: str) -> str | None:
return server_url
async def _fetch_clerk_jwks(*, force_refresh: bool = False) -> dict[str, object]:
global _jwks_cache_payload
global _jwks_cache_at_monotonic
if (
not force_refresh
and _jwks_cache_payload is not None
and monotonic() - _jwks_cache_at_monotonic < _JWKS_CACHE_TTL_SECONDS
):
return _jwks_cache_payload
secret = settings.clerk_secret_key.strip()
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
async with Clerk(
bearer_auth=secret,
server_url=server_url,
timeout_ms=5000,
) as clerk:
jwks = await clerk.jwks.get_async()
if jwks is None:
raise RuntimeError("Clerk JWKS response was empty.")
payload = jwks.model_dump()
if not isinstance(payload, dict):
raise RuntimeError("Clerk JWKS response had invalid shape.")
_jwks_cache_payload = payload
_jwks_cache_at_monotonic = monotonic()
return payload
def _make_authenticate_request_options() -> AuthenticateRequestOptions:
# Follow the clerk-backend-api documented flow: authenticate_request() with a secret key.
return AuthenticateRequestOptions(
secret_key=settings.clerk_secret_key.strip(),
clock_skew_in_ms=int(settings.clerk_leeway * 1000),
accepts_token=["session_token"],
)
def _public_key_for_kid(jwks_payload: dict[str, object], kid: str) -> jwt.PyJWK | None:
try:
jwk_set = jwt.PyJWKSet.from_dict(jwks_payload)
except jwt.PyJWTError:
return None
for key in jwk_set.keys:
if key.key_id == kid:
return key
return None
async def _decode_clerk_token(token: str) -> dict[str, object]:
try:
header = jwt.get_unverified_header(token)
except jwt.PyJWTError as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
kid = _non_empty_str(header.get("kid"))
if kid is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
secret_kind = settings.clerk_secret_key.strip().split("_", maxsplit=1)[0]
for attempt in (False, True):
try:
jwks_payload = await _fetch_clerk_jwks(force_refresh=attempt)
except (ClerkErrors, SDKError, RuntimeError):
logger.warning(
"auth.clerk.jwks.fetch_failed attempt=%s secret_kind=%s",
2 if attempt else 1,
secret_kind,
exc_info=True,
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None
key = _public_key_for_kid(jwks_payload, kid)
if key is None:
if attempt:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
continue
try:
decoded = jwt.decode(
token,
key=key,
algorithms=["RS256"],
options={
"verify_aud": False,
"verify_iat": settings.clerk_verify_iat,
},
leeway=settings.clerk_leeway,
)
except jwt.PyJWTError as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
if not isinstance(decoded, dict):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return {str(k): v for k, v in decoded.items()}
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
async def _authenticate_clerk_request(request: Request) -> RequestState:
# The SDK docs use httpx.Request as the request object; build one from the ASGI request.
httpx_request = httpx.Request(
request.method,
str(request.url),
headers=dict(request.headers),
)
options = _make_authenticate_request_options()
sdk = Clerk(bearer_auth=options.secret_key or "")
return await run_in_threadpool(sdk.authenticate_request, httpx_request, options)
async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | None]:
secret = settings.clerk_secret_key.strip()
secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
try:
async with Clerk(
@@ -293,42 +188,29 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No
) as clerk:
profile = await clerk.users.get_async(user_id=clerk_user_id)
email, name = _extract_clerk_profile(profile)
logger.info(
"auth.clerk.profile.fetch clerk_user_id=%s email=%s name=%s",
clerk_user_id,
email,
name,
)
return email, name
except ClerkErrors as exc:
errors_payload = str(exc)
if len(errors_payload) > 300:
errors_payload = f"{errors_payload[:300]}..."
logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=clerk_errors "
"secret_kind=%s body=%s",
clerk_user_id,
"secret_kind=%s error_type=%s",
clerk_user_id_log,
secret_kind,
errors_payload,
exc.__class__.__name__,
)
except SDKError as exc:
response_body = exc.body.strip() or None
if response_body and len(response_body) > 300:
response_body = f"{response_body[:300]}..."
logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s status=%s reason=sdk_error "
"server_url=%s secret_kind=%s body=%s",
clerk_user_id,
"server_url=%s secret_kind=%s",
clerk_user_id_log,
exc.status_code,
server_url,
secret_kind,
response_body,
)
except httpx.TimeoutException as exc:
logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=timeout "
"server_url=%s secret_kind=%s error=%s",
clerk_user_id,
clerk_user_id_log,
server_url,
secret_kind,
str(exc) or exc.__class__.__name__,
@@ -337,7 +219,7 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No
logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=sdk_exception "
"error_type=%s error=%s",
clerk_user_id,
clerk_user_id_log,
exc.__class__.__name__,
str(exc)[:300],
)
@@ -349,6 +231,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
secret = settings.clerk_secret_key.strip()
secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
try:
async with Clerk(
@@ -357,17 +240,14 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
timeout_ms=5000,
) as clerk:
await clerk.users.delete_async(user_id=clerk_user_id)
logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id)
logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id_log)
except ClerkErrors as exc:
errors_payload = str(exc)
if len(errors_payload) > 300:
errors_payload = f"{errors_payload[:300]}..."
logger.warning(
"auth.clerk.user.delete_failed clerk_user_id=%s reason=clerk_errors "
"secret_kind=%s body=%s",
clerk_user_id,
"secret_kind=%s error_type=%s",
clerk_user_id_log,
secret_kind,
errors_payload,
exc.__class__.__name__,
)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
@@ -375,19 +255,15 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
) from exc
except SDKError as exc:
if exc.status_code == 404:
logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id)
logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id_log)
return
response_body = exc.body.strip() or None
if response_body and len(response_body) > 300:
response_body = f"{response_body[:300]}..."
logger.warning(
"auth.clerk.user.delete_failed clerk_user_id=%s status=%s reason=sdk_error "
"server_url=%s secret_kind=%s body=%s",
clerk_user_id,
"server_url=%s secret_kind=%s",
clerk_user_id_log,
exc.status_code,
server_url,
secret_kind,
response_body,
)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
@@ -396,7 +272,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
except Exception as exc:
logger.warning(
"auth.clerk.user.delete_failed clerk_user_id=%s reason=sdk_exception",
clerk_user_id,
clerk_user_id_log,
exc_info=True,
)
raise HTTPException(
@@ -411,6 +287,7 @@ async def _get_or_sync_user(
clerk_user_id: str,
claims: dict[str, object],
) -> User:
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
claim_email = _extract_claim_email(claims)
claim_name = _extract_claim_name(claims)
defaults: dict[str, object | None] = {
@@ -434,16 +311,6 @@ async def _get_or_sync_user(
email = profile_email or claim_email
name = profile_name or claim_name
logger.info(
"auth.claims.parsed clerk_user_id=%s extracted_email=%s extracted_name=%s "
"claim_email=%s claim_name=%s claims=%s",
clerk_user_id,
profile_email,
profile_name,
claim_email,
claim_name,
_claim_debug_snapshot(claims),
)
changed = False
if email and user.email != email:
@@ -456,18 +323,23 @@ async def _get_or_sync_user(
session.add(user)
await session.commit()
await session.refresh(user)
logger.info(
"auth.user.sync clerk_user_id=%s updated=%s claim_email=%s final_email=%s",
clerk_user_id,
changed,
_normalize_email(claim_email),
_normalize_email(user.email),
)
logger.info(
"auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s",
clerk_user_id_log,
changed,
should_fetch_profile,
)
else:
logger.debug(
"auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s",
clerk_user_id_log,
changed,
should_fetch_profile,
)
if not user.email:
logger.warning(
"auth.user.sync.missing_email clerk_user_id=%s claims=%s",
clerk_user_id,
_claim_debug_snapshot(claims),
"auth.user.sync.missing_email clerk_user_id=%s",
clerk_user_id_log,
)
return user
@@ -478,14 +350,15 @@ def _parse_subject(claims: dict[str, object]) -> str | None:
async def get_auth_context(
request: Request,
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = SESSION_DEP,
) -> AuthContext:
"""Resolve required authenticated user context from Clerk JWT headers."""
if credentials is None:
request_state = await _authenticate_clerk_request(request)
if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
claims = await _decode_clerk_token(credentials.credentials)
claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()}
try:
clerk_user_id = _parse_subject(claims)
except ValidationError as exc:
@@ -516,13 +389,10 @@ async def get_auth_context_optional(
"""Resolve user context if available, otherwise return `None`."""
if request.headers.get("X-Agent-Token"):
return None
if credentials is None:
return None
try:
claims = await _decode_clerk_token(credentials.credentials)
except HTTPException:
request_state = await _authenticate_clerk_request(request)
if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict):
return None
claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()}
try:
clerk_user_id = _parse_subject(claims)

View File

@@ -0,0 +1,96 @@
# ruff: noqa: SLF001
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from fastapi import HTTPException
from app.core import auth
from app.models.users import User
class _FakeSession:
async def commit(self) -> None: # pragma: no cover
raise AssertionError("commit should not be called in these tests")
@pytest.mark.asyncio
async def test_get_auth_context_raises_401_when_clerk_signed_out(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from clerk_backend_api.security.types import AuthStatus, RequestState
async def _fake_authenticate(_request: Any) -> RequestState:
return RequestState(status=AuthStatus.SIGNED_OUT)
monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate)
with pytest.raises(HTTPException) as excinfo:
await auth.get_auth_context( # type: ignore[arg-type]
request=SimpleNamespace(headers={}),
credentials=None,
session=_FakeSession(), # type: ignore[arg-type]
)
assert excinfo.value.status_code == 401
@pytest.mark.asyncio
async def test_get_auth_context_uses_request_state_payload_claims(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from clerk_backend_api.security.types import AuthStatus, RequestState
async def _fake_authenticate(_request: Any) -> RequestState:
return RequestState(status=AuthStatus.SIGNED_IN, token="t", payload={"sub": "user_123"})
async def _fake_get_or_sync_user(
_session: Any,
*,
clerk_user_id: str,
claims: dict[str, object],
) -> User:
assert clerk_user_id == "user_123"
assert claims["sub"] == "user_123"
return User(clerk_user_id="user_123", email="user@example.com", name="User")
async def _fake_ensure_member_for_user(_session: Any, _user: User) -> None:
return None
monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate)
monkeypatch.setattr(auth, "_get_or_sync_user", _fake_get_or_sync_user)
import app.services.organizations as orgs
monkeypatch.setattr(orgs, "ensure_member_for_user", _fake_ensure_member_for_user)
ctx = await auth.get_auth_context( # type: ignore[arg-type]
request=SimpleNamespace(headers={}),
credentials=None,
session=_FakeSession(), # type: ignore[arg-type]
)
assert ctx.actor_type == "user"
assert ctx.user is not None
assert ctx.user.clerk_user_id == "user_123"
@pytest.mark.asyncio
async def test_get_auth_context_optional_returns_none_for_agent_token(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _boom(_request: Any) -> Any: # pragma: no cover
raise AssertionError("_authenticate_clerk_request should not be called")
monkeypatch.setattr(auth, "_authenticate_clerk_request", _boom)
out = await auth.get_auth_context_optional( # type: ignore[arg-type]
request=SimpleNamespace(headers={"X-Agent-Token": "agent"}),
credentials=None,
session=_FakeSession(), # type: ignore[arg-type]
)
assert out is None