refactor: update Clerk authentication integration and improve organization handling

This commit is contained in:
Abhimanyu Saharan
2026-02-09 23:55:52 +05:30
parent 6f76e430f4
commit 3326100205
13 changed files with 763 additions and 269 deletions

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import re
from typing import TYPE_CHECKING, Any
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import SQLModel, col, select
@@ -64,7 +65,6 @@ from app.services.task_dependencies import (
if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession

View File

@@ -2,14 +2,17 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from functools import lru_cache
from time import monotonic
from typing import TYPE_CHECKING, Literal
import jwt
from clerk_backend_api import Clerk
from clerk_backend_api.models.clerkerrors import ClerkErrors
from clerk_backend_api.models.sdkerror import SDKError
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer
from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials
from pydantic import BaseModel, ValidationError
from app.core.config import settings
@@ -18,12 +21,16 @@ from app.db.session import get_session
from app.models.users import User
if TYPE_CHECKING:
from clerk_backend_api.models.user import User as ClerkUser
from sqlmodel.ext.asyncio.session import AsyncSession
logger = logging.getLogger(__name__)
security = HTTPBearer(auto_error=False)
SECURITY_DEP = Depends(security)
SESSION_DEP = Depends(get_session)
CLERK_JWKS_URL_REQUIRED_ERROR = "CLERK_JWKS_URL is not set."
_JWKS_CACHE_TTL_SECONDS = 300.0
_jwks_cache_payload: dict[str, object] | None = None
_jwks_cache_at_monotonic = 0.0
class ClerkTokenPayload(BaseModel):
@@ -32,19 +39,6 @@ class ClerkTokenPayload(BaseModel):
sub: str
@lru_cache
def _build_clerk_http_bearer(*, auto_error: bool) -> ClerkHTTPBearer:
    """Return the Clerk HTTP bearer guard, memoized per ``auto_error`` value.

    Raises:
        RuntimeError: if ``settings.clerk_jwks_url`` is empty.
    """
    jwks_url = settings.clerk_jwks_url
    if not jwks_url:
        raise RuntimeError(CLERK_JWKS_URL_REQUIRED_ERROR)
    return ClerkHTTPBearer(
        config=ClerkConfig(
            jwks_url=jwks_url,
            verify_iat=settings.clerk_verify_iat,
            leeway=settings.clerk_leeway,
        ),
        auto_error=auto_error,
        add_state=True,
    )
@dataclass
class AuthContext:
"""Authenticated user context resolved from inbound auth headers."""
@@ -53,25 +47,349 @@ class AuthContext:
user: User | None = None
def _resolve_clerk_auth(
    request: Request,
    fallback: ClerkCredentials | None,
) -> ClerkCredentials | None:
    """Prefer credentials stored on ``request.state.clerk_auth`` over ``fallback``.

    The state value is used only when it is a ``ClerkCredentials`` instance;
    anything else (including a missing attribute) yields ``fallback``.
    """
    from_state = getattr(request.state, "clerk_auth", None)
    return from_state if isinstance(from_state, ClerkCredentials) else fallback
def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
if not auth_data or not auth_data.decoded:
def _non_empty_str(value: object) -> str | None:
if not isinstance(value, str):
return None
payload = ClerkTokenPayload.model_validate(auth_data.decoded)
cleaned = value.strip()
return cleaned or None
def _normalize_email(value: object) -> str | None:
    """Return ``value`` trimmed and lower-cased, or ``None`` if it is not a non-empty string."""
    cleaned = _non_empty_str(value)
    return cleaned.lower() if cleaned is not None else None
def _extract_claim_email(claims: dict[str, object]) -> str | None:
    """Pick the best email address available in a claims mapping.

    Lookup order:
      1. The scalar claims ``email``, ``email_address``, ``primary_email_address``.
      2. The ``email_addresses`` list entry whose ``id`` equals
         ``primary_email_address_id``.
      3. The first usable entry of ``email_addresses`` (plain string or dict).

    Returns the normalized (trimmed, lower-cased) address, or ``None``.
    """
    scalar_candidates = (
        _normalize_email(claims.get(key))
        for key in ("email", "email_address", "primary_email_address")
    )
    direct = next((email for email in scalar_candidates if email), None)
    if direct:
        return direct

    entries = claims.get("email_addresses")
    if not isinstance(entries, list):
        return None

    wanted_id = _non_empty_str(claims.get("primary_email_address_id"))
    first_seen: str | None = None
    for entry in entries:
        if isinstance(entry, str):
            address = _normalize_email(entry)
        elif isinstance(entry, dict):
            address = _normalize_email(entry.get("email_address") or entry.get("email"))
            # Only dict entries carry an id, so only they can match the primary id.
            if address and wanted_id and _non_empty_str(entry.get("id")) == wanted_id:
                return address
        else:
            continue
        if address and first_seen is None:
            first_seen = address
    return first_seen
def _extract_claim_name(claims: dict[str, object]) -> str | None:
    """Derive a display name from a claims mapping.

    Uses ``name`` or ``full_name`` when present; otherwise joins the given
    name (``given_name``/``first_name``) and family name
    (``family_name``/``last_name``) with a space. Returns ``None`` when no
    part is available.
    """
    display = _non_empty_str(claims.get("name")) or _non_empty_str(claims.get("full_name"))
    if display:
        return display

    given = _non_empty_str(claims.get("given_name")) or _non_empty_str(claims.get("first_name"))
    family = _non_empty_str(claims.get("family_name")) or _non_empty_str(claims.get("last_name"))
    joined = " ".join(filter(None, (given, family)))
    return joined or None
def _claim_debug_snapshot(claims: dict[str, object]) -> dict[str, object]:
    """Summarize the identity-related claims as a plain dict for log output.

    The result lists the sorted claim keys, selected scalar claims passed
    through the string/email normalizers, and up to five entries sampled
    from ``email_addresses``.
    """

    def text(key: str) -> str | None:
        return _non_empty_str(claims.get(key))

    def mail(key: str) -> str | None:
        return _normalize_email(claims.get(key))

    raw_addresses = claims.get("email_addresses")
    addresses = raw_addresses if isinstance(raw_addresses, list) else []

    samples: list[dict[str, str | None]] = []
    for entry in addresses[:5]:
        if isinstance(entry, str):
            samples.append({"id": None, "email": _normalize_email(entry)})
        elif isinstance(entry, dict):
            samples.append(
                {
                    "id": _non_empty_str(entry.get("id")),
                    "email": _normalize_email(
                        entry.get("email_address") or entry.get("email"),
                    ),
                },
            )

    return {
        "keys": sorted(claims),
        "iss": text("iss"),
        "sub": text("sub"),
        "email": mail("email"),
        "email_address": mail("email_address"),
        "primary_email_address": mail("primary_email_address"),
        "primary_email_address_id": text("primary_email_address_id"),
        "email_addresses_count": len(addresses),
        "email_addresses_sample": samples,
        "name": text("name"),
        "full_name": text("full_name"),
        "given_name": text("given_name") or text("first_name"),
        "family_name": text("family_name") or text("last_name"),
    }
def _extract_clerk_profile(profile: ClerkUser | None) -> tuple[str | None, str | None]:
    """Extract ``(email, display_name)`` from a Clerk user profile object.

    Email resolution:
      1. ``profile.email_address`` if it is a non-empty string.
      2. Otherwise the ``email_addresses`` entry whose ``id`` equals
         ``primary_email_address_id``.
      3. Otherwise the first entry with a usable ``email_address``.

    Name resolution, first non-empty wins:
      ``full_name`` -> ``name`` -> "first_name last_name" -> ``username``.

    Attributes are read with ``getattr`` so a profile object missing any of
    them is tolerated. Either element of the result may be ``None``.
    """
    if profile is None:
        return None, None

    profile_email = _normalize_email(getattr(profile, "email_address", None))
    emails = getattr(profile, "email_addresses", None)
    if not profile_email and isinstance(emails, list):
        primary_email_id = _non_empty_str(getattr(profile, "primary_email_address_id", None))
        fallback_email: str | None = None
        for item in emails:
            candidate = _normalize_email(getattr(item, "email_address", None))
            if not candidate:
                continue
            candidate_id = _non_empty_str(getattr(item, "id", None))
            if primary_email_id and candidate_id == primary_email_id:
                profile_email = candidate
                break
            if fallback_email is None:
                fallback_email = candidate
        if profile_email is None:
            profile_email = fallback_email

    first = _non_empty_str(getattr(profile, "first_name", None))
    last = _non_empty_str(getattr(profile, "last_name", None))
    # Join first/last BEFORE falling back to username. Previously a bare
    # first_name short-circuited the chain, so the joined "first last" form
    # was never produced for users that had both parts.
    joined_name = " ".join(part for part in (first, last) if part) or None
    profile_name = (
        _non_empty_str(getattr(profile, "full_name", None))
        or _non_empty_str(getattr(profile, "name", None))
        or joined_name
        or _non_empty_str(getattr(profile, "username", None))
    )
    return profile_email, profile_name
def _normalize_clerk_server_url(raw: str) -> str | None:
server_url = raw.strip().rstrip("/")
if not server_url:
return None
if not server_url.endswith("/v1"):
server_url = f"{server_url}/v1"
return server_url
async def _fetch_clerk_jwks(*, force_refresh: bool = False) -> dict[str, object]:
    """Return the Clerk JWKS document, using a module-level cache.

    A cached payload is returned when one exists, ``force_refresh`` is false,
    and it is younger than ``_JWKS_CACHE_TTL_SECONDS``. Otherwise the JWKS is
    fetched through the Clerk SDK and the cache is replaced.

    Raises:
        RuntimeError: if the SDK returns no JWKS or a non-dict dump.
    """
    global _jwks_cache_payload, _jwks_cache_at_monotonic

    cached = _jwks_cache_payload
    if not force_refresh and cached is not None:
        cache_age = monotonic() - _jwks_cache_at_monotonic
        if cache_age < _JWKS_CACHE_TTL_SECONDS:
            return cached

    api_key = settings.clerk_secret_key.strip()
    base_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
    async with Clerk(bearer_auth=api_key, server_url=base_url, timeout_ms=5000) as client:
        jwks = await client.jwks.get_async()

    if jwks is None:
        raise RuntimeError("Clerk JWKS response was empty.")
    fresh = jwks.model_dump()
    if not isinstance(fresh, dict):
        raise RuntimeError("Clerk JWKS response had invalid shape.")

    _jwks_cache_payload = fresh
    _jwks_cache_at_monotonic = monotonic()
    return fresh
def _public_key_for_kid(jwks_payload: dict[str, object], kid: str) -> jwt.PyJWK | None:
    """Find the key in ``jwks_payload`` whose key id equals ``kid``.

    Returns ``None`` when the payload cannot be parsed as a JWK set or when
    no key has a matching id.
    """
    try:
        parsed = jwt.PyJWKSet.from_dict(jwks_payload)
    except jwt.PyJWTError:
        return None
    return next((jwk for jwk in parsed.keys if jwk.key_id == kid), None)
async def _decode_clerk_token(token: str) -> dict[str, object]:
    """Verify a Clerk-issued JWT and return its claims.

    The token header's ``kid`` selects the verification key from the JWKS.
    If the (possibly cached) JWKS has no such key, the JWKS is re-fetched
    once with ``force_refresh=True`` before giving up.

    Args:
        token: The raw bearer token string.

    Returns:
        The decoded claims with every key coerced to ``str``.

    Raises:
        HTTPException: 401 when the header is unreadable, ``kid`` is missing,
            the JWKS cannot be fetched, no key matches, or verification fails.
    """
    try:
        header = jwt.get_unverified_header(token)
    except jwt.PyJWTError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc

    kid = _non_empty_str(header.get("kid"))
    if kid is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

    # Log only the prefix before the first underscore. A secret without an
    # underscore must NOT be split, because ``split("_")[0]`` would then be
    # the whole secret and it would end up in the warning log below.
    secret = settings.clerk_secret_key.strip()
    secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"

    # First pass may use the cached JWKS; second pass forces a refresh.
    for force_refresh in (False, True):
        try:
            jwks_payload = await _fetch_clerk_jwks(force_refresh=force_refresh)
        except (ClerkErrors, SDKError, RuntimeError):
            logger.warning(
                "auth.clerk.jwks.fetch_failed attempt=%s secret_kind=%s",
                2 if force_refresh else 1,
                secret_kind,
                exc_info=True,
            )
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None

        key = _public_key_for_kid(jwks_payload, kid)
        if key is None:
            if force_refresh:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
            # Unknown kid against the first JWKS: retry with a fresh one.
            continue

        try:
            decoded = jwt.decode(
                token,
                key=key,
                algorithms=["RS256"],
                options={
                    "verify_aud": False,
                    "verify_iat": settings.clerk_verify_iat,
                },
                leeway=settings.clerk_leeway,
            )
        except jwt.PyJWTError as exc:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
        if not isinstance(decoded, dict):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
        return {str(k): v for k, v in decoded.items()}

    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | None]:
    """Look up a user through the Clerk SDK and return ``(email, name)``.

    This is best-effort: any failure is logged as a warning and
    ``(None, None)`` is returned instead of raising. Error bodies written to
    the log are cut to 300 characters.
    """

    def clip(text: str) -> str:
        return f"{text[:300]}..." if len(text) > 300 else text

    secret = settings.clerk_secret_key.strip()
    secret_kind = "unknown" if "_" not in secret else secret.split("_", maxsplit=1)[0]
    server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")

    try:
        async with Clerk(
            bearer_auth=secret,
            server_url=server_url,
            timeout_ms=5000,
        ) as client:
            clerk_profile = await client.users.get_async(user_id=clerk_user_id)
        result = _extract_clerk_profile(clerk_profile)
        logger.info(
            "auth.clerk.profile.fetch clerk_user_id=%s email=%s name=%s",
            clerk_user_id,
            result[0],
            result[1],
        )
        return result
    except ClerkErrors as exc:
        logger.warning(
            "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=clerk_errors "
            "secret_kind=%s body=%s",
            clerk_user_id,
            secret_kind,
            clip(str(exc)),
        )
    except SDKError as exc:
        stripped_body = exc.body.strip()
        logger.warning(
            "auth.clerk.profile.fetch_failed clerk_user_id=%s status=%s reason=sdk_error "
            "server_url=%s secret_kind=%s body=%s",
            clerk_user_id,
            exc.status_code,
            server_url,
            secret_kind,
            clip(stripped_body) if stripped_body else None,
        )
    except Exception:
        logger.warning(
            "auth.clerk.profile.fetch_failed clerk_user_id=%s reason=sdk_exception",
            clerk_user_id,
            exc_info=True,
        )
    return None, None
async def _get_or_sync_user(
    session: AsyncSession,
    *,
    clerk_user_id: str,
    claims: dict[str, object],
) -> User:
    """Load or create the local ``User`` for a Clerk id and sync profile fields.

    Email and name come from the Clerk profile lookup; ``claims`` is only used
    to build the debug snapshot written to the log. The stored email is
    replaced whenever the fetched one differs, while the name is filled in
    only if the user has none yet. Changes are committed and the row is
    refreshed.
    """
    fetched_email, fetched_name = await _fetch_clerk_profile(clerk_user_id)
    logger.info(
        "auth.claims.parsed clerk_user_id=%s extracted_email=%s extracted_name=%s claims=%s",
        clerk_user_id,
        fetched_email,
        fetched_name,
        _claim_debug_snapshot(claims),
    )

    defaults: dict[str, object | None] = {"email": fetched_email, "name": fetched_name}
    user, _ = await crud.get_or_create(
        session,
        User,
        clerk_user_id=clerk_user_id,
        defaults=defaults,
    )

    dirty = False
    if fetched_email and user.email != fetched_email:
        user.email = fetched_email
        dirty = True
    if fetched_name and not user.name:
        user.name = fetched_name
        dirty = True

    if dirty:
        session.add(user)
        await session.commit()
        await session.refresh(user)

    logger.info(
        "auth.user.sync clerk_user_id=%s updated=%s claim_email=%s final_email=%s",
        clerk_user_id,
        dirty,
        _normalize_email(defaults.get("email")),
        _normalize_email(user.email),
    )
    if not user.email:
        logger.warning(
            "auth.user.sync.missing_email clerk_user_id=%s claims=%s",
            clerk_user_id,
            _claim_debug_snapshot(claims),
        )
    return user
def _parse_subject(claims: dict[str, object]) -> str | None:
    """Validate ``claims`` against ``ClerkTokenPayload`` and return its ``sub`` field.

    Validation errors from pydantic propagate to the caller.
    """
    return ClerkTokenPayload.model_validate(claims).sub
async def get_auth_context(
request: Request,
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = SESSION_DEP,
) -> AuthContext:
@@ -79,37 +397,18 @@ async def get_auth_context(
if credentials is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
claims = await _decode_clerk_token(credentials.credentials)
try:
guard = _build_clerk_http_bearer(auto_error=False)
clerk_credentials = await guard(request)
except (RuntimeError, ValueError) as exc:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc
except HTTPException as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
auth_data = _resolve_clerk_auth(request, clerk_credentials)
try:
clerk_user_id = _parse_subject(auth_data)
clerk_user_id = _parse_subject(claims)
except ValidationError as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
if not clerk_user_id:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
claims: dict[str, object] = {}
if auth_data and auth_data.decoded:
claims = auth_data.decoded
email_obj = claims.get("email")
name_obj = claims.get("name")
defaults: dict[str, object | None] = {
"email": email_obj if isinstance(email_obj, str) else None,
"name": name_obj if isinstance(name_obj, str) else None,
}
user, _created = await crud.get_or_create(
user = await _get_or_sync_user(
session,
User,
clerk_user_id=clerk_user_id,
defaults=defaults,
claims=claims,
)
from app.services.organizations import ensure_member_for_user
@@ -133,36 +432,21 @@ async def get_auth_context_optional(
return None
try:
guard = _build_clerk_http_bearer(auto_error=False)
clerk_credentials = await guard(request)
except (RuntimeError, ValueError) as exc:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc
claims = await _decode_clerk_token(credentials.credentials)
except HTTPException:
return None
auth_data = _resolve_clerk_auth(request, clerk_credentials)
try:
clerk_user_id = _parse_subject(auth_data)
clerk_user_id = _parse_subject(claims)
except ValidationError:
return None
if not clerk_user_id:
return None
claims: dict[str, object] = {}
if auth_data and auth_data.decoded:
claims = auth_data.decoded
email_obj = claims.get("email")
name_obj = claims.get("name")
defaults: dict[str, object | None] = {
"email": email_obj if isinstance(email_obj, str) else None,
"name": name_obj if isinstance(name_obj, str) else None,
}
user, _created = await crud.get_or_create(
user = await _get_or_sync_user(
session,
User,
clerk_user_id=clerk_user_id,
defaults=defaults,
claims=claims,
)
from app.services.organizations import ensure_member_for_user

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Self
from pydantic import model_validator
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
BACKEND_ROOT = Path(__file__).resolve().parents[2]
@@ -28,7 +28,8 @@ class Settings(BaseSettings):
redis_url: str = "redis://localhost:6379/0"
# Clerk auth (auth only; roles stored in DB)
clerk_jwks_url: str = ""
clerk_secret_key: str = Field(min_length=1)
clerk_api_url: str = "https://api.clerk.com"
clerk_verify_iat: bool = True
clerk_leeway: float = 10.0
@@ -52,6 +53,8 @@ class Settings(BaseSettings):
@model_validator(mode="after")
def _defaults(self) -> Self:
if not self.clerk_secret_key.strip():
raise ValueError("CLERK_SECRET_KEY must be set and non-empty.")
# In dev, default to applying Alembic migrations at startup to avoid
# schema drift (e.g. missing newly-added columns).
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
from sqlmodel import Field
from app.core.time import utcnow
@@ -18,7 +17,6 @@ class Organization(QueryModel, table=True):
"""Top-level organization tenant record."""
__tablename__ = "organizations" # pyright: ignore[reportAssignmentType]
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str = Field(index=True)

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable
from fastapi import HTTPException, status
from sqlalchemy import func, or_
from sqlalchemy import or_
from sqlmodel import col, select
from app.core.time import utcnow
@@ -48,23 +48,6 @@ def is_org_admin(member: OrganizationMember) -> bool:
return member.role in ADMIN_ROLES
async def get_default_org(session: AsyncSession) -> Organization | None:
    """Return the organization named ``DEFAULT_ORG_NAME``, or ``None`` if absent."""
    default_org_query = Organization.objects.filter_by(name=DEFAULT_ORG_NAME)
    return await default_org_query.first(session)
async def ensure_default_org(session: AsyncSession) -> Organization:
    """Return the default organization, creating and committing it if missing.

    A newly created row gets the same timestamp for ``created_at`` and
    ``updated_at``.
    """
    org = await get_default_org(session)
    if org is not None:
        return org
    # Take the timestamp once so both columns agree on a fresh row.
    now = utcnow()
    org = Organization(name=DEFAULT_ORG_NAME, created_at=now, updated_at=now)
    session.add(org)
    await session.commit()
    await session.refresh(org)
    return org
async def get_member(
session: AsyncSession,
*,
@@ -216,31 +199,41 @@ async def ensure_member_for_user(
if existing is not None:
return existing
# Serialize first-time provisioning per user to avoid concurrent duplicate org/member creation.
await session.exec(
select(User.id)
.where(col(User.id) == user.id)
.with_for_update(),
)
existing_member = await get_first_membership(session, user.id)
if existing_member is not None:
if user.active_organization_id != existing_member.organization_id:
user.active_organization_id = existing_member.organization_id
session.add(user)
await session.commit()
return existing_member
if user.email:
invite = await _find_pending_invite(session, user.email)
if invite is not None:
return await accept_invite(session, invite, user)
org = await ensure_default_org(session)
now = utcnow()
member_count = (
await session.exec(
select(func.count()).where(
col(OrganizationMember.organization_id) == org.id,
),
)
).one()
is_first = int(member_count or 0) == 0
org = Organization(name=DEFAULT_ORG_NAME, created_at=now, updated_at=now)
session.add(org)
await session.flush()
org_id = org.id
member = OrganizationMember(
organization_id=org.id,
organization_id=org_id,
user_id=user.id,
role="owner" if is_first else "member",
all_boards_read=is_first,
all_boards_write=is_first,
role="owner",
all_boards_read=True,
all_boards_write=True,
created_at=now,
updated_at=now,
)
user.active_organization_id = org.id
user.active_organization_id = org_id
session.add(user)
session.add(member)
await session.commit()