refactor(auth): streamline authentication flow and enhance error logging

This commit is contained in:
Abhimanyu Saharan
2026-02-10 18:40:42 +05:30
parent 5c25c4bb91
commit 01888062a6
2 changed files with 155 additions and 189 deletions

View File

@@ -4,17 +4,17 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from time import monotonic
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
import httpx import httpx
import jwt
from clerk_backend_api import Clerk from clerk_backend_api import Clerk
from clerk_backend_api.models.clerkerrors import ClerkErrors from clerk_backend_api.models.clerkerrors import ClerkErrors
from clerk_backend_api.models.sdkerror import SDKError from clerk_backend_api.models.sdkerror import SDKError
from clerk_backend_api.security.types import AuthenticateRequestOptions, AuthStatus, RequestState
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from starlette.concurrency import run_in_threadpool
from app.core.config import settings from app.core.config import settings
from app.db import crud from app.db import crud
@@ -29,9 +29,6 @@ logger = logging.getLogger(__name__)
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
SECURITY_DEP = Depends(security) SECURITY_DEP = Depends(security)
SESSION_DEP = Depends(get_session) SESSION_DEP = Depends(get_session)
_JWKS_CACHE_TTL_SECONDS = 300.0
_jwks_cache_payload: dict[str, object] | None = None
_jwks_cache_at_monotonic = 0.0
class ClerkTokenPayload(BaseModel): class ClerkTokenPayload(BaseModel):
@@ -107,42 +104,6 @@ def _extract_claim_name(claims: dict[str, object]) -> str | None:
return " ".join(parts) return " ".join(parts)
def _claim_debug_snapshot(claims: dict[str, object]) -> dict[str, object]:
email_addresses = claims.get("email_addresses")
email_samples: list[dict[str, str | None]] = []
if isinstance(email_addresses, list):
for item in email_addresses[:5]:
if isinstance(item, dict):
email_samples.append(
{
"id": _non_empty_str(item.get("id")),
"email": _normalize_email(
item.get("email_address") or item.get("email"),
),
},
)
elif isinstance(item, str):
email_samples.append({"id": None, "email": _normalize_email(item)})
return {
"keys": sorted(claims.keys()),
"iss": _non_empty_str(claims.get("iss")),
"sub": _non_empty_str(claims.get("sub")),
"email": _normalize_email(claims.get("email")),
"email_address": _normalize_email(claims.get("email_address")),
"primary_email_address": _normalize_email(claims.get("primary_email_address")),
"primary_email_address_id": _non_empty_str(claims.get("primary_email_address_id")),
"email_addresses_count": len(email_addresses) if isinstance(email_addresses, list) else 0,
"email_addresses_sample": email_samples,
"name": _non_empty_str(claims.get("name")),
"full_name": _non_empty_str(claims.get("full_name")),
"given_name": _non_empty_str(claims.get("given_name"))
or _non_empty_str(claims.get("first_name")),
"family_name": _non_empty_str(claims.get("family_name"))
or _non_empty_str(claims.get("last_name")),
}
def _extract_clerk_profile(profile: ClerkUser | None) -> tuple[str | None, str | None]: def _extract_clerk_profile(profile: ClerkUser | None) -> tuple[str | None, str | None]:
if profile is None: if profile is None:
return None, None return None, None
@@ -192,98 +153,32 @@ def _normalize_clerk_server_url(raw: str) -> str | None:
return server_url return server_url
async def _fetch_clerk_jwks(*, force_refresh: bool = False) -> dict[str, object]: def _make_authenticate_request_options() -> AuthenticateRequestOptions:
global _jwks_cache_payload # Follow the clerk-backend-api documented flow: authenticate_request() with a secret key.
global _jwks_cache_at_monotonic return AuthenticateRequestOptions(
secret_key=settings.clerk_secret_key.strip(),
if ( clock_skew_in_ms=int(settings.clerk_leeway * 1000),
not force_refresh accepts_token=["session_token"],
and _jwks_cache_payload is not None
and monotonic() - _jwks_cache_at_monotonic < _JWKS_CACHE_TTL_SECONDS
):
return _jwks_cache_payload
secret = settings.clerk_secret_key.strip()
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
async with Clerk(
bearer_auth=secret,
server_url=server_url,
timeout_ms=5000,
) as clerk:
jwks = await clerk.jwks.get_async()
if jwks is None:
raise RuntimeError("Clerk JWKS response was empty.")
payload = jwks.model_dump()
if not isinstance(payload, dict):
raise RuntimeError("Clerk JWKS response had invalid shape.")
_jwks_cache_payload = payload
_jwks_cache_at_monotonic = monotonic()
return payload
def _public_key_for_kid(jwks_payload: dict[str, object], kid: str) -> jwt.PyJWK | None:
try:
jwk_set = jwt.PyJWKSet.from_dict(jwks_payload)
except jwt.PyJWTError:
return None
for key in jwk_set.keys:
if key.key_id == kid:
return key
return None
async def _decode_clerk_token(token: str) -> dict[str, object]:
try:
header = jwt.get_unverified_header(token)
except jwt.PyJWTError as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
kid = _non_empty_str(header.get("kid"))
if kid is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
secret_kind = settings.clerk_secret_key.strip().split("_", maxsplit=1)[0]
for attempt in (False, True):
try:
jwks_payload = await _fetch_clerk_jwks(force_refresh=attempt)
except (ClerkErrors, SDKError, RuntimeError):
logger.warning(
"auth.clerk.jwks.fetch_failed attempt=%s secret_kind=%s",
2 if attempt else 1,
secret_kind,
exc_info=True,
) )
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None
key = _public_key_for_kid(jwks_payload, kid)
if key is None: async def _authenticate_clerk_request(request: Request) -> RequestState:
if attempt: # The SDK docs use httpx.Request as the request object; build one from the ASGI request.
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) httpx_request = httpx.Request(
continue request.method,
try: str(request.url),
decoded = jwt.decode( headers=dict(request.headers),
token,
key=key,
algorithms=["RS256"],
options={
"verify_aud": False,
"verify_iat": settings.clerk_verify_iat,
},
leeway=settings.clerk_leeway,
) )
except jwt.PyJWTError as exc: options = _make_authenticate_request_options()
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc sdk = Clerk(bearer_auth=options.secret_key or "")
if not isinstance(decoded, dict): return await run_in_threadpool(sdk.authenticate_request, httpx_request, options)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return {str(k): v for k, v in decoded.items()}
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | None]: async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | None]:
secret = settings.clerk_secret_key.strip() secret = settings.clerk_secret_key.strip()
secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown" secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "") server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
try: try:
async with Clerk( async with Clerk(
@@ -293,42 +188,29 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No
) as clerk: ) as clerk:
profile = await clerk.users.get_async(user_id=clerk_user_id) profile = await clerk.users.get_async(user_id=clerk_user_id)
email, name = _extract_clerk_profile(profile) email, name = _extract_clerk_profile(profile)
logger.info(
"auth.clerk.profile.fetch clerk_user_id=%s email=%s name=%s",
clerk_user_id,
email,
name,
)
return email, name return email, name
except ClerkErrors as exc: except ClerkErrors as exc:
errors_payload = str(exc)
if len(errors_payload) > 300:
errors_payload = f"{errors_payload[:300]}..."
logger.warning( logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=clerk_errors " "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=clerk_errors "
"secret_kind=%s body=%s", "secret_kind=%s error_type=%s",
clerk_user_id, clerk_user_id_log,
secret_kind, secret_kind,
errors_payload, exc.__class__.__name__,
) )
except SDKError as exc: except SDKError as exc:
response_body = exc.body.strip() or None
if response_body and len(response_body) > 300:
response_body = f"{response_body[:300]}..."
logger.warning( logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s status=%s reason=sdk_error " "auth.clerk.profile.fetch_failed clerk_user_id=%s status=%s reason=sdk_error "
"server_url=%s secret_kind=%s body=%s", "server_url=%s secret_kind=%s",
clerk_user_id, clerk_user_id_log,
exc.status_code, exc.status_code,
server_url, server_url,
secret_kind, secret_kind,
response_body,
) )
except httpx.TimeoutException as exc: except httpx.TimeoutException as exc:
logger.warning( logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=timeout " "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=timeout "
"server_url=%s secret_kind=%s error=%s", "server_url=%s secret_kind=%s error=%s",
clerk_user_id, clerk_user_id_log,
server_url, server_url,
secret_kind, secret_kind,
str(exc) or exc.__class__.__name__, str(exc) or exc.__class__.__name__,
@@ -337,7 +219,7 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No
logger.warning( logger.warning(
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=sdk_exception " "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=sdk_exception "
"error_type=%s error=%s", "error_type=%s error=%s",
clerk_user_id, clerk_user_id_log,
exc.__class__.__name__, exc.__class__.__name__,
str(exc)[:300], str(exc)[:300],
) )
@@ -349,6 +231,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
secret = settings.clerk_secret_key.strip() secret = settings.clerk_secret_key.strip()
secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown" secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "") server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
try: try:
async with Clerk( async with Clerk(
@@ -357,17 +240,14 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
timeout_ms=5000, timeout_ms=5000,
) as clerk: ) as clerk:
await clerk.users.delete_async(user_id=clerk_user_id) await clerk.users.delete_async(user_id=clerk_user_id)
logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id) logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id_log)
except ClerkErrors as exc: except ClerkErrors as exc:
errors_payload = str(exc)
if len(errors_payload) > 300:
errors_payload = f"{errors_payload[:300]}..."
logger.warning( logger.warning(
"auth.clerk.user.delete_failed clerk_user_id=%s reason=clerk_errors " "auth.clerk.user.delete_failed clerk_user_id=%s reason=clerk_errors "
"secret_kind=%s body=%s", "secret_kind=%s error_type=%s",
clerk_user_id, clerk_user_id_log,
secret_kind, secret_kind,
errors_payload, exc.__class__.__name__,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, status_code=status.HTTP_502_BAD_GATEWAY,
@@ -375,19 +255,15 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
) from exc ) from exc
except SDKError as exc: except SDKError as exc:
if exc.status_code == 404: if exc.status_code == 404:
logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id) logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id_log)
return return
response_body = exc.body.strip() or None
if response_body and len(response_body) > 300:
response_body = f"{response_body[:300]}..."
logger.warning( logger.warning(
"auth.clerk.user.delete_failed clerk_user_id=%s status=%s reason=sdk_error " "auth.clerk.user.delete_failed clerk_user_id=%s status=%s reason=sdk_error "
"server_url=%s secret_kind=%s body=%s", "server_url=%s secret_kind=%s",
clerk_user_id, clerk_user_id_log,
exc.status_code, exc.status_code,
server_url, server_url,
secret_kind, secret_kind,
response_body,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, status_code=status.HTTP_502_BAD_GATEWAY,
@@ -396,7 +272,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"auth.clerk.user.delete_failed clerk_user_id=%s reason=sdk_exception", "auth.clerk.user.delete_failed clerk_user_id=%s reason=sdk_exception",
clerk_user_id, clerk_user_id_log,
exc_info=True, exc_info=True,
) )
raise HTTPException( raise HTTPException(
@@ -411,6 +287,7 @@ async def _get_or_sync_user(
clerk_user_id: str, clerk_user_id: str,
claims: dict[str, object], claims: dict[str, object],
) -> User: ) -> User:
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
claim_email = _extract_claim_email(claims) claim_email = _extract_claim_email(claims)
claim_name = _extract_claim_name(claims) claim_name = _extract_claim_name(claims)
defaults: dict[str, object | None] = { defaults: dict[str, object | None] = {
@@ -434,16 +311,6 @@ async def _get_or_sync_user(
email = profile_email or claim_email email = profile_email or claim_email
name = profile_name or claim_name name = profile_name or claim_name
logger.info(
"auth.claims.parsed clerk_user_id=%s extracted_email=%s extracted_name=%s "
"claim_email=%s claim_name=%s claims=%s",
clerk_user_id,
profile_email,
profile_name,
claim_email,
claim_name,
_claim_debug_snapshot(claims),
)
changed = False changed = False
if email and user.email != email: if email and user.email != email:
@@ -457,17 +324,22 @@ async def _get_or_sync_user(
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
logger.info( logger.info(
"auth.user.sync clerk_user_id=%s updated=%s claim_email=%s final_email=%s", "auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s",
clerk_user_id, clerk_user_id_log,
changed, changed,
_normalize_email(claim_email), should_fetch_profile,
_normalize_email(user.email), )
else:
logger.debug(
"auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s",
clerk_user_id_log,
changed,
should_fetch_profile,
) )
if not user.email: if not user.email:
logger.warning( logger.warning(
"auth.user.sync.missing_email clerk_user_id=%s claims=%s", "auth.user.sync.missing_email clerk_user_id=%s",
clerk_user_id, clerk_user_id_log,
_claim_debug_snapshot(claims),
) )
return user return user
@@ -478,14 +350,15 @@ def _parse_subject(claims: dict[str, object]) -> str | None:
async def get_auth_context( async def get_auth_context(
request: Request,
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP, credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
) -> AuthContext: ) -> AuthContext:
"""Resolve required authenticated user context from Clerk JWT headers.""" """Resolve required authenticated user context from Clerk JWT headers."""
if credentials is None: request_state = await _authenticate_clerk_request(request)
if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()}
claims = await _decode_clerk_token(credentials.credentials)
try: try:
clerk_user_id = _parse_subject(claims) clerk_user_id = _parse_subject(claims)
except ValidationError as exc: except ValidationError as exc:
@@ -516,13 +389,10 @@ async def get_auth_context_optional(
"""Resolve user context if available, otherwise return `None`.""" """Resolve user context if available, otherwise return `None`."""
if request.headers.get("X-Agent-Token"): if request.headers.get("X-Agent-Token"):
return None return None
if credentials is None: request_state = await _authenticate_clerk_request(request)
return None if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict):
try:
claims = await _decode_clerk_token(credentials.credentials)
except HTTPException:
return None return None
claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()}
try: try:
clerk_user_id = _parse_subject(claims) clerk_user_id = _parse_subject(claims)

View File

@@ -0,0 +1,96 @@
# ruff: noqa: SLF001
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from fastapi import HTTPException
from app.core import auth
from app.models.users import User
class _FakeSession:
async def commit(self) -> None: # pragma: no cover
raise AssertionError("commit should not be called in these tests")
@pytest.mark.asyncio
async def test_get_auth_context_raises_401_when_clerk_signed_out(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from clerk_backend_api.security.types import AuthStatus, RequestState
async def _fake_authenticate(_request: Any) -> RequestState:
return RequestState(status=AuthStatus.SIGNED_OUT)
monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate)
with pytest.raises(HTTPException) as excinfo:
await auth.get_auth_context( # type: ignore[arg-type]
request=SimpleNamespace(headers={}),
credentials=None,
session=_FakeSession(), # type: ignore[arg-type]
)
assert excinfo.value.status_code == 401
@pytest.mark.asyncio
async def test_get_auth_context_uses_request_state_payload_claims(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from clerk_backend_api.security.types import AuthStatus, RequestState
async def _fake_authenticate(_request: Any) -> RequestState:
return RequestState(status=AuthStatus.SIGNED_IN, token="t", payload={"sub": "user_123"})
async def _fake_get_or_sync_user(
_session: Any,
*,
clerk_user_id: str,
claims: dict[str, object],
) -> User:
assert clerk_user_id == "user_123"
assert claims["sub"] == "user_123"
return User(clerk_user_id="user_123", email="user@example.com", name="User")
async def _fake_ensure_member_for_user(_session: Any, _user: User) -> None:
return None
monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate)
monkeypatch.setattr(auth, "_get_or_sync_user", _fake_get_or_sync_user)
import app.services.organizations as orgs
monkeypatch.setattr(orgs, "ensure_member_for_user", _fake_ensure_member_for_user)
ctx = await auth.get_auth_context( # type: ignore[arg-type]
request=SimpleNamespace(headers={}),
credentials=None,
session=_FakeSession(), # type: ignore[arg-type]
)
assert ctx.actor_type == "user"
assert ctx.user is not None
assert ctx.user.clerk_user_id == "user_123"
@pytest.mark.asyncio
async def test_get_auth_context_optional_returns_none_for_agent_token(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _boom(_request: Any) -> Any: # pragma: no cover
raise AssertionError("_authenticate_clerk_request should not be called")
monkeypatch.setattr(auth, "_authenticate_clerk_request", _boom)
out = await auth.get_auth_context_optional( # type: ignore[arg-type]
request=SimpleNamespace(headers={"X-Agent-Token": "agent"}),
credentials=None,
session=_FakeSession(), # type: ignore[arg-type]
)
assert out is None