refactor(auth): streamline authentication flow and enhance error logging
This commit is contained in:
@@ -4,17 +4,17 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from time import monotonic
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from clerk_backend_api import Clerk
|
||||
from clerk_backend_api.models.clerkerrors import ClerkErrors
|
||||
from clerk_backend_api.models.sdkerror import SDKError
|
||||
from clerk_backend_api.security.types import AuthenticateRequestOptions, AuthStatus, RequestState
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db import crud
|
||||
@@ -29,9 +29,6 @@ logger = logging.getLogger(__name__)
|
||||
security = HTTPBearer(auto_error=False)
|
||||
SECURITY_DEP = Depends(security)
|
||||
SESSION_DEP = Depends(get_session)
|
||||
_JWKS_CACHE_TTL_SECONDS = 300.0
|
||||
_jwks_cache_payload: dict[str, object] | None = None
|
||||
_jwks_cache_at_monotonic = 0.0
|
||||
|
||||
|
||||
class ClerkTokenPayload(BaseModel):
|
||||
@@ -107,42 +104,6 @@ def _extract_claim_name(claims: dict[str, object]) -> str | None:
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _claim_debug_snapshot(claims: dict[str, object]) -> dict[str, object]:
|
||||
email_addresses = claims.get("email_addresses")
|
||||
email_samples: list[dict[str, str | None]] = []
|
||||
if isinstance(email_addresses, list):
|
||||
for item in email_addresses[:5]:
|
||||
if isinstance(item, dict):
|
||||
email_samples.append(
|
||||
{
|
||||
"id": _non_empty_str(item.get("id")),
|
||||
"email": _normalize_email(
|
||||
item.get("email_address") or item.get("email"),
|
||||
),
|
||||
},
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
email_samples.append({"id": None, "email": _normalize_email(item)})
|
||||
|
||||
return {
|
||||
"keys": sorted(claims.keys()),
|
||||
"iss": _non_empty_str(claims.get("iss")),
|
||||
"sub": _non_empty_str(claims.get("sub")),
|
||||
"email": _normalize_email(claims.get("email")),
|
||||
"email_address": _normalize_email(claims.get("email_address")),
|
||||
"primary_email_address": _normalize_email(claims.get("primary_email_address")),
|
||||
"primary_email_address_id": _non_empty_str(claims.get("primary_email_address_id")),
|
||||
"email_addresses_count": len(email_addresses) if isinstance(email_addresses, list) else 0,
|
||||
"email_addresses_sample": email_samples,
|
||||
"name": _non_empty_str(claims.get("name")),
|
||||
"full_name": _non_empty_str(claims.get("full_name")),
|
||||
"given_name": _non_empty_str(claims.get("given_name"))
|
||||
or _non_empty_str(claims.get("first_name")),
|
||||
"family_name": _non_empty_str(claims.get("family_name"))
|
||||
or _non_empty_str(claims.get("last_name")),
|
||||
}
|
||||
|
||||
|
||||
def _extract_clerk_profile(profile: ClerkUser | None) -> tuple[str | None, str | None]:
|
||||
if profile is None:
|
||||
return None, None
|
||||
@@ -192,98 +153,32 @@ def _normalize_clerk_server_url(raw: str) -> str | None:
|
||||
return server_url
|
||||
|
||||
|
||||
async def _fetch_clerk_jwks(*, force_refresh: bool = False) -> dict[str, object]:
|
||||
global _jwks_cache_payload
|
||||
global _jwks_cache_at_monotonic
|
||||
|
||||
if (
|
||||
not force_refresh
|
||||
and _jwks_cache_payload is not None
|
||||
and monotonic() - _jwks_cache_at_monotonic < _JWKS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _jwks_cache_payload
|
||||
|
||||
secret = settings.clerk_secret_key.strip()
|
||||
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
|
||||
async with Clerk(
|
||||
bearer_auth=secret,
|
||||
server_url=server_url,
|
||||
timeout_ms=5000,
|
||||
) as clerk:
|
||||
jwks = await clerk.jwks.get_async()
|
||||
if jwks is None:
|
||||
raise RuntimeError("Clerk JWKS response was empty.")
|
||||
payload = jwks.model_dump()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("Clerk JWKS response had invalid shape.")
|
||||
_jwks_cache_payload = payload
|
||||
_jwks_cache_at_monotonic = monotonic()
|
||||
return payload
|
||||
def _make_authenticate_request_options() -> AuthenticateRequestOptions:
|
||||
# Follow the clerk-backend-api documented flow: authenticate_request() with a secret key.
|
||||
return AuthenticateRequestOptions(
|
||||
secret_key=settings.clerk_secret_key.strip(),
|
||||
clock_skew_in_ms=int(settings.clerk_leeway * 1000),
|
||||
accepts_token=["session_token"],
|
||||
)
|
||||
|
||||
|
||||
def _public_key_for_kid(jwks_payload: dict[str, object], kid: str) -> jwt.PyJWK | None:
|
||||
try:
|
||||
jwk_set = jwt.PyJWKSet.from_dict(jwks_payload)
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
for key in jwk_set.keys:
|
||||
if key.key_id == kid:
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
async def _decode_clerk_token(token: str) -> dict[str, object]:
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
except jwt.PyJWTError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
|
||||
|
||||
kid = _non_empty_str(header.get("kid"))
|
||||
if kid is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
secret_kind = settings.clerk_secret_key.strip().split("_", maxsplit=1)[0]
|
||||
for attempt in (False, True):
|
||||
try:
|
||||
jwks_payload = await _fetch_clerk_jwks(force_refresh=attempt)
|
||||
except (ClerkErrors, SDKError, RuntimeError):
|
||||
logger.warning(
|
||||
"auth.clerk.jwks.fetch_failed attempt=%s secret_kind=%s",
|
||||
2 if attempt else 1,
|
||||
secret_kind,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None
|
||||
|
||||
key = _public_key_for_kid(jwks_payload, kid)
|
||||
if key is None:
|
||||
if attempt:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
continue
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
key=key,
|
||||
algorithms=["RS256"],
|
||||
options={
|
||||
"verify_aud": False,
|
||||
"verify_iat": settings.clerk_verify_iat,
|
||||
},
|
||||
leeway=settings.clerk_leeway,
|
||||
)
|
||||
except jwt.PyJWTError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
|
||||
if not isinstance(decoded, dict):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return {str(k): v for k, v in decoded.items()}
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
async def _authenticate_clerk_request(request: Request) -> RequestState:
|
||||
# The SDK docs use httpx.Request as the request object; build one from the ASGI request.
|
||||
httpx_request = httpx.Request(
|
||||
request.method,
|
||||
str(request.url),
|
||||
headers=dict(request.headers),
|
||||
)
|
||||
options = _make_authenticate_request_options()
|
||||
sdk = Clerk(bearer_auth=options.secret_key or "")
|
||||
return await run_in_threadpool(sdk.authenticate_request, httpx_request, options)
|
||||
|
||||
|
||||
async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | None]:
|
||||
secret = settings.clerk_secret_key.strip()
|
||||
secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"
|
||||
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
|
||||
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
|
||||
|
||||
try:
|
||||
async with Clerk(
|
||||
@@ -293,42 +188,29 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No
|
||||
) as clerk:
|
||||
profile = await clerk.users.get_async(user_id=clerk_user_id)
|
||||
email, name = _extract_clerk_profile(profile)
|
||||
logger.info(
|
||||
"auth.clerk.profile.fetch clerk_user_id=%s email=%s name=%s",
|
||||
clerk_user_id,
|
||||
email,
|
||||
name,
|
||||
)
|
||||
return email, name
|
||||
except ClerkErrors as exc:
|
||||
errors_payload = str(exc)
|
||||
if len(errors_payload) > 300:
|
||||
errors_payload = f"{errors_payload[:300]}..."
|
||||
logger.warning(
|
||||
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=clerk_errors "
|
||||
"secret_kind=%s body=%s",
|
||||
clerk_user_id,
|
||||
"secret_kind=%s error_type=%s",
|
||||
clerk_user_id_log,
|
||||
secret_kind,
|
||||
errors_payload,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
except SDKError as exc:
|
||||
response_body = exc.body.strip() or None
|
||||
if response_body and len(response_body) > 300:
|
||||
response_body = f"{response_body[:300]}..."
|
||||
logger.warning(
|
||||
"auth.clerk.profile.fetch_failed clerk_user_id=%s status=%s reason=sdk_error "
|
||||
"server_url=%s secret_kind=%s body=%s",
|
||||
clerk_user_id,
|
||||
"server_url=%s secret_kind=%s",
|
||||
clerk_user_id_log,
|
||||
exc.status_code,
|
||||
server_url,
|
||||
secret_kind,
|
||||
response_body,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
logger.warning(
|
||||
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=timeout "
|
||||
"server_url=%s secret_kind=%s error=%s",
|
||||
clerk_user_id,
|
||||
clerk_user_id_log,
|
||||
server_url,
|
||||
secret_kind,
|
||||
str(exc) or exc.__class__.__name__,
|
||||
@@ -337,7 +219,7 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No
|
||||
logger.warning(
|
||||
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=sdk_exception "
|
||||
"error_type=%s error=%s",
|
||||
clerk_user_id,
|
||||
clerk_user_id_log,
|
||||
exc.__class__.__name__,
|
||||
str(exc)[:300],
|
||||
)
|
||||
@@ -349,6 +231,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
|
||||
secret = settings.clerk_secret_key.strip()
|
||||
secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"
|
||||
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
|
||||
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
|
||||
|
||||
try:
|
||||
async with Clerk(
|
||||
@@ -357,17 +240,14 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
|
||||
timeout_ms=5000,
|
||||
) as clerk:
|
||||
await clerk.users.delete_async(user_id=clerk_user_id)
|
||||
logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id)
|
||||
logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id_log)
|
||||
except ClerkErrors as exc:
|
||||
errors_payload = str(exc)
|
||||
if len(errors_payload) > 300:
|
||||
errors_payload = f"{errors_payload[:300]}..."
|
||||
logger.warning(
|
||||
"auth.clerk.user.delete_failed clerk_user_id=%s reason=clerk_errors "
|
||||
"secret_kind=%s body=%s",
|
||||
clerk_user_id,
|
||||
"secret_kind=%s error_type=%s",
|
||||
clerk_user_id_log,
|
||||
secret_kind,
|
||||
errors_payload,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
@@ -375,19 +255,15 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
|
||||
) from exc
|
||||
except SDKError as exc:
|
||||
if exc.status_code == 404:
|
||||
logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id)
|
||||
logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id_log)
|
||||
return
|
||||
response_body = exc.body.strip() or None
|
||||
if response_body and len(response_body) > 300:
|
||||
response_body = f"{response_body[:300]}..."
|
||||
logger.warning(
|
||||
"auth.clerk.user.delete_failed clerk_user_id=%s status=%s reason=sdk_error "
|
||||
"server_url=%s secret_kind=%s body=%s",
|
||||
clerk_user_id,
|
||||
"server_url=%s secret_kind=%s",
|
||||
clerk_user_id_log,
|
||||
exc.status_code,
|
||||
server_url,
|
||||
secret_kind,
|
||||
response_body,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
@@ -396,7 +272,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None:
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"auth.clerk.user.delete_failed clerk_user_id=%s reason=sdk_exception",
|
||||
clerk_user_id,
|
||||
clerk_user_id_log,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -411,6 +287,7 @@ async def _get_or_sync_user(
|
||||
clerk_user_id: str,
|
||||
claims: dict[str, object],
|
||||
) -> User:
|
||||
clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else ""
|
||||
claim_email = _extract_claim_email(claims)
|
||||
claim_name = _extract_claim_name(claims)
|
||||
defaults: dict[str, object | None] = {
|
||||
@@ -434,16 +311,6 @@ async def _get_or_sync_user(
|
||||
|
||||
email = profile_email or claim_email
|
||||
name = profile_name or claim_name
|
||||
logger.info(
|
||||
"auth.claims.parsed clerk_user_id=%s extracted_email=%s extracted_name=%s "
|
||||
"claim_email=%s claim_name=%s claims=%s",
|
||||
clerk_user_id,
|
||||
profile_email,
|
||||
profile_name,
|
||||
claim_email,
|
||||
claim_name,
|
||||
_claim_debug_snapshot(claims),
|
||||
)
|
||||
|
||||
changed = False
|
||||
if email and user.email != email:
|
||||
@@ -456,18 +323,23 @@ async def _get_or_sync_user(
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
logger.info(
|
||||
"auth.user.sync clerk_user_id=%s updated=%s claim_email=%s final_email=%s",
|
||||
clerk_user_id,
|
||||
changed,
|
||||
_normalize_email(claim_email),
|
||||
_normalize_email(user.email),
|
||||
)
|
||||
logger.info(
|
||||
"auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s",
|
||||
clerk_user_id_log,
|
||||
changed,
|
||||
should_fetch_profile,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s",
|
||||
clerk_user_id_log,
|
||||
changed,
|
||||
should_fetch_profile,
|
||||
)
|
||||
if not user.email:
|
||||
logger.warning(
|
||||
"auth.user.sync.missing_email clerk_user_id=%s claims=%s",
|
||||
clerk_user_id,
|
||||
_claim_debug_snapshot(claims),
|
||||
"auth.user.sync.missing_email clerk_user_id=%s",
|
||||
clerk_user_id_log,
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -478,14 +350,15 @@ def _parse_subject(claims: dict[str, object]) -> str | None:
|
||||
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> AuthContext:
|
||||
"""Resolve required authenticated user context from Clerk JWT headers."""
|
||||
if credentials is None:
|
||||
request_state = await _authenticate_clerk_request(request)
|
||||
if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
claims = await _decode_clerk_token(credentials.credentials)
|
||||
claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()}
|
||||
try:
|
||||
clerk_user_id = _parse_subject(claims)
|
||||
except ValidationError as exc:
|
||||
@@ -516,13 +389,10 @@ async def get_auth_context_optional(
|
||||
"""Resolve user context if available, otherwise return `None`."""
|
||||
if request.headers.get("X-Agent-Token"):
|
||||
return None
|
||||
if credentials is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
claims = await _decode_clerk_token(credentials.credentials)
|
||||
except HTTPException:
|
||||
request_state = await _authenticate_clerk_request(request)
|
||||
if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict):
|
||||
return None
|
||||
claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()}
|
||||
|
||||
try:
|
||||
clerk_user_id = _parse_subject(claims)
|
||||
|
||||
96
backend/tests/test_authenticate_request_flow.py
Normal file
96
backend/tests/test_authenticate_request_flow.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# ruff: noqa: SLF001
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core import auth
|
||||
from app.models.users import User
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
async def commit(self) -> None: # pragma: no cover
|
||||
raise AssertionError("commit should not be called in these tests")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_auth_context_raises_401_when_clerk_signed_out(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from clerk_backend_api.security.types import AuthStatus, RequestState
|
||||
|
||||
async def _fake_authenticate(_request: Any) -> RequestState:
|
||||
return RequestState(status=AuthStatus.SIGNED_OUT)
|
||||
|
||||
monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate)
|
||||
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await auth.get_auth_context( # type: ignore[arg-type]
|
||||
request=SimpleNamespace(headers={}),
|
||||
credentials=None,
|
||||
session=_FakeSession(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_auth_context_uses_request_state_payload_claims(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from clerk_backend_api.security.types import AuthStatus, RequestState
|
||||
|
||||
async def _fake_authenticate(_request: Any) -> RequestState:
|
||||
return RequestState(status=AuthStatus.SIGNED_IN, token="t", payload={"sub": "user_123"})
|
||||
|
||||
async def _fake_get_or_sync_user(
|
||||
_session: Any,
|
||||
*,
|
||||
clerk_user_id: str,
|
||||
claims: dict[str, object],
|
||||
) -> User:
|
||||
assert clerk_user_id == "user_123"
|
||||
assert claims["sub"] == "user_123"
|
||||
return User(clerk_user_id="user_123", email="user@example.com", name="User")
|
||||
|
||||
async def _fake_ensure_member_for_user(_session: Any, _user: User) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate)
|
||||
monkeypatch.setattr(auth, "_get_or_sync_user", _fake_get_or_sync_user)
|
||||
|
||||
import app.services.organizations as orgs
|
||||
|
||||
monkeypatch.setattr(orgs, "ensure_member_for_user", _fake_ensure_member_for_user)
|
||||
|
||||
ctx = await auth.get_auth_context( # type: ignore[arg-type]
|
||||
request=SimpleNamespace(headers={}),
|
||||
credentials=None,
|
||||
session=_FakeSession(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
assert ctx.actor_type == "user"
|
||||
assert ctx.user is not None
|
||||
assert ctx.user.clerk_user_id == "user_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_auth_context_optional_returns_none_for_agent_token(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _boom(_request: Any) -> Any: # pragma: no cover
|
||||
raise AssertionError("_authenticate_clerk_request should not be called")
|
||||
|
||||
monkeypatch.setattr(auth, "_authenticate_clerk_request", _boom)
|
||||
|
||||
out = await auth.get_auth_context_optional( # type: ignore[arg-type]
|
||||
request=SimpleNamespace(headers={"X-Agent-Token": "agent"}),
|
||||
credentials=None,
|
||||
session=_FakeSession(), # type: ignore[arg-type]
|
||||
)
|
||||
assert out is None
|
||||
|
||||
Reference in New Issue
Block a user