diff --git a/backend/app/core/auth.py b/backend/app/core/auth.py index d40a9e56..5d67b32a 100644 --- a/backend/app/core/auth.py +++ b/backend/app/core/auth.py @@ -4,17 +4,17 @@ from __future__ import annotations import logging from dataclasses import dataclass -from time import monotonic from typing import TYPE_CHECKING, Literal import httpx -import jwt from clerk_backend_api import Clerk from clerk_backend_api.models.clerkerrors import ClerkErrors from clerk_backend_api.models.sdkerror import SDKError +from clerk_backend_api.security.types import AuthenticateRequestOptions, AuthStatus, RequestState from fastapi import Depends, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel, ValidationError +from starlette.concurrency import run_in_threadpool from app.core.config import settings from app.db import crud @@ -29,9 +29,6 @@ logger = logging.getLogger(__name__) security = HTTPBearer(auto_error=False) SECURITY_DEP = Depends(security) SESSION_DEP = Depends(get_session) -_JWKS_CACHE_TTL_SECONDS = 300.0 -_jwks_cache_payload: dict[str, object] | None = None -_jwks_cache_at_monotonic = 0.0 class ClerkTokenPayload(BaseModel): @@ -107,42 +104,6 @@ def _extract_claim_name(claims: dict[str, object]) -> str | None: return " ".join(parts) -def _claim_debug_snapshot(claims: dict[str, object]) -> dict[str, object]: - email_addresses = claims.get("email_addresses") - email_samples: list[dict[str, str | None]] = [] - if isinstance(email_addresses, list): - for item in email_addresses[:5]: - if isinstance(item, dict): - email_samples.append( - { - "id": _non_empty_str(item.get("id")), - "email": _normalize_email( - item.get("email_address") or item.get("email"), - ), - }, - ) - elif isinstance(item, str): - email_samples.append({"id": None, "email": _normalize_email(item)}) - - return { - "keys": sorted(claims.keys()), - "iss": _non_empty_str(claims.get("iss")), - "sub": _non_empty_str(claims.get("sub")), - "email": _normalize_email(claims.get("email")), - "email_address": _normalize_email(claims.get("email_address")), - "primary_email_address": _normalize_email(claims.get("primary_email_address")), - "primary_email_address_id": _non_empty_str(claims.get("primary_email_address_id")), - "email_addresses_count": len(email_addresses) if isinstance(email_addresses, list) else 0, - "email_addresses_sample": email_samples, - "name": _non_empty_str(claims.get("name")), - "full_name": _non_empty_str(claims.get("full_name")), - "given_name": _non_empty_str(claims.get("given_name")) - or _non_empty_str(claims.get("first_name")), - "family_name": _non_empty_str(claims.get("family_name")) - or _non_empty_str(claims.get("last_name")), - } - - def _extract_clerk_profile(profile: ClerkUser | None) -> tuple[str | None, str | None]: if profile is None: return None, None @@ -192,98 +153,32 @@ def _normalize_clerk_server_url(raw: str) -> str | None: return server_url -async def _fetch_clerk_jwks(*, force_refresh: bool = False) -> dict[str, object]: - global _jwks_cache_payload - global _jwks_cache_at_monotonic - - if ( - not force_refresh - and _jwks_cache_payload is not None - and monotonic() - _jwks_cache_at_monotonic < _JWKS_CACHE_TTL_SECONDS - ): - return _jwks_cache_payload - - secret = settings.clerk_secret_key.strip() - server_url = _normalize_clerk_server_url(settings.clerk_api_url or "") - async with Clerk( - bearer_auth=secret, - server_url=server_url, - timeout_ms=5000, - ) as clerk: - jwks = await clerk.jwks.get_async() - if jwks is None: - raise RuntimeError("Clerk JWKS response was empty.") - payload = jwks.model_dump() - if not isinstance(payload, dict): - raise RuntimeError("Clerk JWKS response had invalid shape.") - _jwks_cache_payload = payload - _jwks_cache_at_monotonic = monotonic() - return payload +def _make_authenticate_request_options() -> AuthenticateRequestOptions: + # Follow the clerk-backend-api documented flow: authenticate_request() with a secret key. + return AuthenticateRequestOptions( + secret_key=settings.clerk_secret_key.strip(), + clock_skew_in_ms=int(settings.clerk_leeway * 1000), + accepts_token=["session_token"], + ) -def _public_key_for_kid(jwks_payload: dict[str, object], kid: str) -> jwt.PyJWK | None: - try: - jwk_set = jwt.PyJWKSet.from_dict(jwks_payload) - except jwt.PyJWTError: - return None - for key in jwk_set.keys: - if key.key_id == kid: - return key - return None - - -async def _decode_clerk_token(token: str) -> dict[str, object]: - try: - header = jwt.get_unverified_header(token) - except jwt.PyJWTError as exc: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc - - kid = _non_empty_str(header.get("kid")) - if kid is None: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - - secret_kind = settings.clerk_secret_key.strip().split("_", maxsplit=1)[0] - for attempt in (False, True): - try: - jwks_payload = await _fetch_clerk_jwks(force_refresh=attempt) - except (ClerkErrors, SDKError, RuntimeError): - logger.warning( - "auth.clerk.jwks.fetch_failed attempt=%s secret_kind=%s", - 2 if attempt else 1, - secret_kind, - exc_info=True, - ) - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None - - key = _public_key_for_kid(jwks_payload, kid) - if key is None: - if attempt: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - continue - try: - decoded = jwt.decode( - token, - key=key, - algorithms=["RS256"], - options={ - "verify_aud": False, - "verify_iat": settings.clerk_verify_iat, - }, - leeway=settings.clerk_leeway, - ) - except jwt.PyJWTError as exc: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc - if not isinstance(decoded, dict): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - return {str(k): v for k, v in decoded.items()} - - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) +async def _authenticate_clerk_request(request: Request) -> RequestState: + # The SDK docs use httpx.Request as the request object; build one from the ASGI request. + httpx_request = httpx.Request( + request.method, + str(request.url), + headers=dict(request.headers), + ) + options = _make_authenticate_request_options() + sdk = Clerk(bearer_auth=options.secret_key or "") + return await run_in_threadpool(sdk.authenticate_request, httpx_request, options) async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | None]: secret = settings.clerk_secret_key.strip() secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown" server_url = _normalize_clerk_server_url(settings.clerk_api_url or "") + clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else "" try: async with Clerk( @@ -293,42 +188,29 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No ) as clerk: profile = await clerk.users.get_async(user_id=clerk_user_id) email, name = _extract_clerk_profile(profile) - logger.info( - "auth.clerk.profile.fetch clerk_user_id=%s email=%s name=%s", - clerk_user_id, - email, - name, - ) return email, name except ClerkErrors as exc: - errors_payload = str(exc) - if len(errors_payload) > 300: - errors_payload = f"{errors_payload[:300]}..." logger.warning( "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=clerk_errors " - "secret_kind=%s body=%s", - clerk_user_id, + "secret_kind=%s error_type=%s", + clerk_user_id_log, secret_kind, - errors_payload, + exc.__class__.__name__, ) except SDKError as exc: - response_body = exc.body.strip() or None - if response_body and len(response_body) > 300: - response_body = f"{response_body[:300]}..." logger.warning( "auth.clerk.profile.fetch_failed clerk_user_id=%s status=%s reason=sdk_error " - "server_url=%s secret_kind=%s body=%s", - clerk_user_id, + "server_url=%s secret_kind=%s", + clerk_user_id_log, exc.status_code, server_url, secret_kind, - response_body, ) except httpx.TimeoutException as exc: logger.warning( "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=timeout " "server_url=%s secret_kind=%s error=%s", - clerk_user_id, + clerk_user_id_log, server_url, secret_kind, str(exc) or exc.__class__.__name__, @@ -337,7 +219,7 @@ async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | No logger.warning( "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=sdk_exception " "error_type=%s error=%s", - clerk_user_id, + clerk_user_id_log, exc.__class__.__name__, str(exc)[:300], ) @@ -349,6 +231,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None: secret = settings.clerk_secret_key.strip() secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown" server_url = _normalize_clerk_server_url(settings.clerk_api_url or "") + clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else "" try: async with Clerk( @@ -357,17 +240,14 @@ async def delete_clerk_user(clerk_user_id: str) -> None: timeout_ms=5000, ) as clerk: await clerk.users.delete_async(user_id=clerk_user_id) - logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id) + logger.info("auth.clerk.user.delete clerk_user_id=%s", clerk_user_id_log) except ClerkErrors as exc: - errors_payload = str(exc) - if len(errors_payload) > 300: - errors_payload = f"{errors_payload[:300]}..." logger.warning( "auth.clerk.user.delete_failed clerk_user_id=%s reason=clerk_errors " - "secret_kind=%s body=%s", - clerk_user_id, + "secret_kind=%s error_type=%s", + clerk_user_id_log, secret_kind, - errors_payload, + exc.__class__.__name__, ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, @@ -375,19 +255,15 @@ async def delete_clerk_user(clerk_user_id: str) -> None: ) from exc except SDKError as exc: if exc.status_code == 404: - logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id) + logger.info("auth.clerk.user.delete_missing clerk_user_id=%s", clerk_user_id_log) return - response_body = exc.body.strip() or None - if response_body and len(response_body) > 300: - response_body = f"{response_body[:300]}..." logger.warning( "auth.clerk.user.delete_failed clerk_user_id=%s status=%s reason=sdk_error " - "server_url=%s secret_kind=%s body=%s", - clerk_user_id, + "server_url=%s secret_kind=%s", + clerk_user_id_log, exc.status_code, server_url, secret_kind, - response_body, ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, @@ -396,7 +272,7 @@ async def delete_clerk_user(clerk_user_id: str) -> None: except Exception as exc: logger.warning( "auth.clerk.user.delete_failed clerk_user_id=%s reason=sdk_exception", - clerk_user_id, + clerk_user_id_log, exc_info=True, ) raise HTTPException( @@ -411,6 +287,7 @@ async def _get_or_sync_user( clerk_user_id: str, claims: dict[str, object], ) -> User: + clerk_user_id_log = clerk_user_id[-6:] if clerk_user_id else "" claim_email = _extract_claim_email(claims) claim_name = _extract_claim_name(claims) defaults: dict[str, object | None] = { @@ -434,16 +311,6 @@ async def _get_or_sync_user( email = profile_email or claim_email name = profile_name or claim_name - logger.info( - "auth.claims.parsed clerk_user_id=%s extracted_email=%s extracted_name=%s " - "claim_email=%s claim_name=%s claims=%s", - clerk_user_id, - profile_email, - profile_name, - claim_email, - claim_name, - _claim_debug_snapshot(claims), - ) changed = False if email and user.email != email: @@ -456,18 +323,23 @@ async def _get_or_sync_user( session.add(user) await session.commit() await session.refresh(user) - logger.info( - "auth.user.sync clerk_user_id=%s updated=%s claim_email=%s final_email=%s", - clerk_user_id, - changed, - _normalize_email(claim_email), - _normalize_email(user.email), - ) + logger.info( + "auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s", + clerk_user_id_log, + changed, + should_fetch_profile, + ) + else: + logger.debug( + "auth.user.sync clerk_user_id=%s updated=%s fetched_profile=%s", + clerk_user_id_log, + changed, + should_fetch_profile, + ) if not user.email: logger.warning( - "auth.user.sync.missing_email clerk_user_id=%s claims=%s", - clerk_user_id, - _claim_debug_snapshot(claims), + "auth.user.sync.missing_email clerk_user_id=%s", + clerk_user_id_log, ) return user @@ -478,14 +350,15 @@ def _parse_subject(claims: dict[str, object]) -> str | None: async def get_auth_context( + request: Request, credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP, session: AsyncSession = SESSION_DEP, ) -> AuthContext: """Resolve required authenticated user context from Clerk JWT headers.""" - if credentials is None: + request_state = await _authenticate_clerk_request(request) + if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - - claims = await _decode_clerk_token(credentials.credentials) + claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()} try: clerk_user_id = _parse_subject(claims) except ValidationError as exc: @@ -516,13 +389,10 @@ async def get_auth_context_optional( """Resolve user context if available, otherwise return `None`.""" if request.headers.get("X-Agent-Token"): return None - if credentials is None: - return None - - try: - claims = await _decode_clerk_token(credentials.credentials) - except HTTPException: + request_state = await _authenticate_clerk_request(request) + if request_state.status != AuthStatus.SIGNED_IN or not isinstance(request_state.payload, dict): return None + claims: dict[str, object] = {str(k): v for k, v in request_state.payload.items()} try: clerk_user_id = _parse_subject(claims) diff --git a/backend/tests/test_authenticate_request_flow.py b/backend/tests/test_authenticate_request_flow.py new file mode 100644 index 00000000..e721180a --- /dev/null +++ b/backend/tests/test_authenticate_request_flow.py @@ -0,0 +1,96 @@ +# ruff: noqa: SLF001 + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from fastapi import HTTPException + +from app.core import auth +from app.models.users import User + + +class _FakeSession: + async def commit(self) -> None: # pragma: no cover + raise AssertionError("commit should not be called in these tests") + + +@pytest.mark.asyncio +async def test_get_auth_context_raises_401_when_clerk_signed_out( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from clerk_backend_api.security.types import AuthStatus, RequestState + + async def _fake_authenticate(_request: Any) -> RequestState: + return RequestState(status=AuthStatus.SIGNED_OUT) + + monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate) + + with pytest.raises(HTTPException) as excinfo: + await auth.get_auth_context( # type: ignore[arg-type] + request=SimpleNamespace(headers={}), + credentials=None, + session=_FakeSession(), # type: ignore[arg-type] + ) + + assert excinfo.value.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_auth_context_uses_request_state_payload_claims( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from clerk_backend_api.security.types import AuthStatus, RequestState + + async def _fake_authenticate(_request: Any) -> RequestState: + return RequestState(status=AuthStatus.SIGNED_IN, token="t", payload={"sub": "user_123"}) + + async def _fake_get_or_sync_user( + _session: Any, + *, + clerk_user_id: str, + claims: dict[str, object], + ) -> User: + assert clerk_user_id == "user_123" + assert claims["sub"] == "user_123" + return User(clerk_user_id="user_123", email="user@example.com", name="User") + + async def _fake_ensure_member_for_user(_session: Any, _user: User) -> None: + return None + + monkeypatch.setattr(auth, "_authenticate_clerk_request", _fake_authenticate) + monkeypatch.setattr(auth, "_get_or_sync_user", _fake_get_or_sync_user) + + import app.services.organizations as orgs + + monkeypatch.setattr(orgs, "ensure_member_for_user", _fake_ensure_member_for_user) + + ctx = await auth.get_auth_context( # type: ignore[arg-type] + request=SimpleNamespace(headers={}), + credentials=None, + session=_FakeSession(), # type: ignore[arg-type] + ) + + assert ctx.actor_type == "user" + assert ctx.user is not None + assert ctx.user.clerk_user_id == "user_123" + + +@pytest.mark.asyncio +async def test_get_auth_context_optional_returns_none_for_agent_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def _boom(_request: Any) -> Any: # pragma: no cover + raise AssertionError("_authenticate_clerk_request should not be called") + + monkeypatch.setattr(auth, "_authenticate_clerk_request", _boom) + + out = await auth.get_auth_context_optional( # type: ignore[arg-type] + request=SimpleNamespace(headers={"X-Agent-Token": "agent"}), + credentials=None, + session=_FakeSession(), # type: ignore[arg-type] + ) + assert out is None +