refactor: update Clerk authentication integration and improve organization handling
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel import SQLModel, col, select
|
||||
@@ -64,7 +65,6 @@ from app.services.task_dependencies import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -2,14 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from time import monotonic
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import jwt
|
||||
from clerk_backend_api import Clerk
|
||||
from clerk_backend_api.models.clerkerrors import ClerkErrors
|
||||
from clerk_backend_api.models.sdkerror import SDKError
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer
|
||||
from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -18,12 +21,16 @@ from app.db.session import get_session
|
||||
from app.models.users import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from clerk_backend_api.models.user import User as ClerkUser
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
security = HTTPBearer(auto_error=False)
|
||||
SECURITY_DEP = Depends(security)
|
||||
SESSION_DEP = Depends(get_session)
|
||||
CLERK_JWKS_URL_REQUIRED_ERROR = "CLERK_JWKS_URL is not set."
|
||||
_JWKS_CACHE_TTL_SECONDS = 300.0
|
||||
_jwks_cache_payload: dict[str, object] | None = None
|
||||
_jwks_cache_at_monotonic = 0.0
|
||||
|
||||
|
||||
class ClerkTokenPayload(BaseModel):
|
||||
@@ -32,19 +39,6 @@ class ClerkTokenPayload(BaseModel):
|
||||
sub: str
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _build_clerk_http_bearer(*, auto_error: bool) -> ClerkHTTPBearer:
|
||||
"""Create and cache the Clerk HTTP bearer guard."""
|
||||
if not settings.clerk_jwks_url:
|
||||
raise RuntimeError(CLERK_JWKS_URL_REQUIRED_ERROR)
|
||||
clerk_config = ClerkConfig(
|
||||
jwks_url=settings.clerk_jwks_url,
|
||||
verify_iat=settings.clerk_verify_iat,
|
||||
leeway=settings.clerk_leeway,
|
||||
)
|
||||
return ClerkHTTPBearer(config=clerk_config, auto_error=auto_error, add_state=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
"""Authenticated user context resolved from inbound auth headers."""
|
||||
@@ -53,25 +47,349 @@ class AuthContext:
|
||||
user: User | None = None
|
||||
|
||||
|
||||
def _resolve_clerk_auth(
|
||||
request: Request,
|
||||
fallback: ClerkCredentials | None,
|
||||
) -> ClerkCredentials | None:
|
||||
auth_data = getattr(request.state, "clerk_auth", None)
|
||||
if isinstance(auth_data, ClerkCredentials):
|
||||
return auth_data
|
||||
return fallback
|
||||
|
||||
|
||||
def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
|
||||
if not auth_data or not auth_data.decoded:
|
||||
def _non_empty_str(value: object) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
payload = ClerkTokenPayload.model_validate(auth_data.decoded)
|
||||
cleaned = value.strip()
|
||||
return cleaned or None
|
||||
|
||||
|
||||
def _normalize_email(value: object) -> str | None:
|
||||
text = _non_empty_str(value)
|
||||
if text is None:
|
||||
return None
|
||||
return text.lower()
|
||||
|
||||
|
||||
def _extract_claim_email(claims: dict[str, object]) -> str | None:
|
||||
for key in ("email", "email_address", "primary_email_address"):
|
||||
email = _normalize_email(claims.get(key))
|
||||
if email:
|
||||
return email
|
||||
|
||||
primary_email_id = _non_empty_str(claims.get("primary_email_address_id"))
|
||||
email_addresses = claims.get("email_addresses")
|
||||
if not isinstance(email_addresses, list):
|
||||
return None
|
||||
|
||||
fallback_email: str | None = None
|
||||
for item in email_addresses:
|
||||
if isinstance(item, str):
|
||||
normalized = _normalize_email(item)
|
||||
if normalized and fallback_email is None:
|
||||
fallback_email = normalized
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
candidate = _normalize_email(item.get("email_address") or item.get("email"))
|
||||
if not candidate:
|
||||
continue
|
||||
candidate_id = _non_empty_str(item.get("id"))
|
||||
if primary_email_id and candidate_id == primary_email_id:
|
||||
return candidate
|
||||
if fallback_email is None:
|
||||
fallback_email = candidate
|
||||
return fallback_email
|
||||
|
||||
|
||||
def _extract_claim_name(claims: dict[str, object]) -> str | None:
|
||||
for key in ("name", "full_name"):
|
||||
text = _non_empty_str(claims.get(key))
|
||||
if text:
|
||||
return text
|
||||
|
||||
first = _non_empty_str(claims.get("given_name")) or _non_empty_str(claims.get("first_name"))
|
||||
last = _non_empty_str(claims.get("family_name")) or _non_empty_str(claims.get("last_name"))
|
||||
parts = [part for part in (first, last) if part]
|
||||
if not parts:
|
||||
return None
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _claim_debug_snapshot(claims: dict[str, object]) -> dict[str, object]:
|
||||
email_addresses = claims.get("email_addresses")
|
||||
email_samples: list[dict[str, str | None]] = []
|
||||
if isinstance(email_addresses, list):
|
||||
for item in email_addresses[:5]:
|
||||
if isinstance(item, dict):
|
||||
email_samples.append(
|
||||
{
|
||||
"id": _non_empty_str(item.get("id")),
|
||||
"email": _normalize_email(
|
||||
item.get("email_address") or item.get("email"),
|
||||
),
|
||||
},
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
email_samples.append({"id": None, "email": _normalize_email(item)})
|
||||
|
||||
return {
|
||||
"keys": sorted(claims.keys()),
|
||||
"iss": _non_empty_str(claims.get("iss")),
|
||||
"sub": _non_empty_str(claims.get("sub")),
|
||||
"email": _normalize_email(claims.get("email")),
|
||||
"email_address": _normalize_email(claims.get("email_address")),
|
||||
"primary_email_address": _normalize_email(claims.get("primary_email_address")),
|
||||
"primary_email_address_id": _non_empty_str(claims.get("primary_email_address_id")),
|
||||
"email_addresses_count": len(email_addresses) if isinstance(email_addresses, list) else 0,
|
||||
"email_addresses_sample": email_samples,
|
||||
"name": _non_empty_str(claims.get("name")),
|
||||
"full_name": _non_empty_str(claims.get("full_name")),
|
||||
"given_name": _non_empty_str(claims.get("given_name"))
|
||||
or _non_empty_str(claims.get("first_name")),
|
||||
"family_name": _non_empty_str(claims.get("family_name"))
|
||||
or _non_empty_str(claims.get("last_name")),
|
||||
}
|
||||
|
||||
|
||||
def _extract_clerk_profile(profile: ClerkUser | None) -> tuple[str | None, str | None]:
|
||||
if profile is None:
|
||||
return None, None
|
||||
|
||||
profile_email = _normalize_email(getattr(profile, "email_address", None))
|
||||
primary_email_id = _non_empty_str(getattr(profile, "primary_email_address_id", None))
|
||||
emails = getattr(profile, "email_addresses", None)
|
||||
if not profile_email and isinstance(emails, list):
|
||||
fallback_email: str | None = None
|
||||
for item in emails:
|
||||
candidate = _normalize_email(
|
||||
getattr(item, "email_address", None),
|
||||
)
|
||||
if not candidate:
|
||||
continue
|
||||
candidate_id = _non_empty_str(getattr(item, "id", None))
|
||||
if primary_email_id and candidate_id == primary_email_id:
|
||||
profile_email = candidate
|
||||
break
|
||||
if fallback_email is None:
|
||||
fallback_email = candidate
|
||||
if profile_email is None:
|
||||
profile_email = fallback_email
|
||||
|
||||
profile_name = (
|
||||
_non_empty_str(getattr(profile, "full_name", None))
|
||||
or _non_empty_str(getattr(profile, "name", None))
|
||||
or _non_empty_str(getattr(profile, "first_name", None))
|
||||
or _non_empty_str(getattr(profile, "username", None))
|
||||
)
|
||||
if not profile_name:
|
||||
first = _non_empty_str(getattr(profile, "first_name", None))
|
||||
last = _non_empty_str(getattr(profile, "last_name", None))
|
||||
parts = [part for part in (first, last) if part]
|
||||
if parts:
|
||||
profile_name = " ".join(parts)
|
||||
|
||||
return profile_email, profile_name
|
||||
|
||||
|
||||
def _normalize_clerk_server_url(raw: str) -> str | None:
|
||||
server_url = raw.strip().rstrip("/")
|
||||
if not server_url:
|
||||
return None
|
||||
if not server_url.endswith("/v1"):
|
||||
server_url = f"{server_url}/v1"
|
||||
return server_url
|
||||
|
||||
|
||||
async def _fetch_clerk_jwks(*, force_refresh: bool = False) -> dict[str, object]:
|
||||
global _jwks_cache_payload
|
||||
global _jwks_cache_at_monotonic
|
||||
|
||||
if (
|
||||
not force_refresh
|
||||
and _jwks_cache_payload is not None
|
||||
and monotonic() - _jwks_cache_at_monotonic < _JWKS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _jwks_cache_payload
|
||||
|
||||
secret = settings.clerk_secret_key.strip()
|
||||
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
|
||||
async with Clerk(
|
||||
bearer_auth=secret,
|
||||
server_url=server_url,
|
||||
timeout_ms=5000,
|
||||
) as clerk:
|
||||
jwks = await clerk.jwks.get_async()
|
||||
if jwks is None:
|
||||
raise RuntimeError("Clerk JWKS response was empty.")
|
||||
payload = jwks.model_dump()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("Clerk JWKS response had invalid shape.")
|
||||
_jwks_cache_payload = payload
|
||||
_jwks_cache_at_monotonic = monotonic()
|
||||
return payload
|
||||
|
||||
|
||||
def _public_key_for_kid(jwks_payload: dict[str, object], kid: str) -> jwt.PyJWK | None:
|
||||
try:
|
||||
jwk_set = jwt.PyJWKSet.from_dict(jwks_payload)
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
for key in jwk_set.keys:
|
||||
if key.key_id == kid:
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
async def _decode_clerk_token(token: str) -> dict[str, object]:
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
except jwt.PyJWTError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
|
||||
|
||||
kid = _non_empty_str(header.get("kid"))
|
||||
if kid is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
secret_kind = settings.clerk_secret_key.strip().split("_", maxsplit=1)[0]
|
||||
for attempt in (False, True):
|
||||
try:
|
||||
jwks_payload = await _fetch_clerk_jwks(force_refresh=attempt)
|
||||
except (ClerkErrors, SDKError, RuntimeError):
|
||||
logger.warning(
|
||||
"auth.clerk.jwks.fetch_failed attempt=%s secret_kind=%s",
|
||||
2 if attempt else 1,
|
||||
secret_kind,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None
|
||||
|
||||
key = _public_key_for_kid(jwks_payload, kid)
|
||||
if key is None:
|
||||
if attempt:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
continue
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
key=key,
|
||||
algorithms=["RS256"],
|
||||
options={
|
||||
"verify_aud": False,
|
||||
"verify_iat": settings.clerk_verify_iat,
|
||||
},
|
||||
leeway=settings.clerk_leeway,
|
||||
)
|
||||
except jwt.PyJWTError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
|
||||
if not isinstance(decoded, dict):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return {str(k): v for k, v in decoded.items()}
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
async def _fetch_clerk_profile(clerk_user_id: str) -> tuple[str | None, str | None]:
|
||||
secret = settings.clerk_secret_key.strip()
|
||||
secret_kind = secret.split("_", maxsplit=1)[0] if "_" in secret else "unknown"
|
||||
server_url = _normalize_clerk_server_url(settings.clerk_api_url or "")
|
||||
|
||||
try:
|
||||
async with Clerk(
|
||||
bearer_auth=secret,
|
||||
server_url=server_url,
|
||||
timeout_ms=5000,
|
||||
) as clerk:
|
||||
profile = await clerk.users.get_async(user_id=clerk_user_id)
|
||||
email, name = _extract_clerk_profile(profile)
|
||||
logger.info(
|
||||
"auth.clerk.profile.fetch clerk_user_id=%s email=%s name=%s",
|
||||
clerk_user_id,
|
||||
email,
|
||||
name,
|
||||
)
|
||||
return email, name
|
||||
except ClerkErrors as exc:
|
||||
errors_payload = str(exc)
|
||||
if len(errors_payload) > 300:
|
||||
errors_payload = f"{errors_payload[:300]}..."
|
||||
logger.warning(
|
||||
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=clerk_errors "
|
||||
"secret_kind=%s body=%s",
|
||||
clerk_user_id,
|
||||
secret_kind,
|
||||
errors_payload,
|
||||
)
|
||||
except SDKError as exc:
|
||||
response_body = exc.body.strip() or None
|
||||
if response_body and len(response_body) > 300:
|
||||
response_body = f"{response_body[:300]}..."
|
||||
logger.warning(
|
||||
"auth.clerk.profile.fetch_failed clerk_user_id=%s status=%s reason=sdk_error "
|
||||
"server_url=%s secret_kind=%s body=%s",
|
||||
clerk_user_id,
|
||||
exc.status_code,
|
||||
server_url,
|
||||
secret_kind,
|
||||
response_body,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"auth.clerk.profile.fetch_failed clerk_user_id=%s reason=sdk_exception",
|
||||
clerk_user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None, None
|
||||
|
||||
|
||||
async def _get_or_sync_user(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
clerk_user_id: str,
|
||||
claims: dict[str, object],
|
||||
) -> User:
|
||||
email, name = await _fetch_clerk_profile(clerk_user_id)
|
||||
logger.info(
|
||||
"auth.claims.parsed clerk_user_id=%s extracted_email=%s extracted_name=%s claims=%s",
|
||||
clerk_user_id,
|
||||
email,
|
||||
name,
|
||||
_claim_debug_snapshot(claims),
|
||||
)
|
||||
defaults: dict[str, object | None] = {
|
||||
"email": email,
|
||||
"name": name,
|
||||
}
|
||||
user, _created = await crud.get_or_create(
|
||||
session,
|
||||
User,
|
||||
clerk_user_id=clerk_user_id,
|
||||
defaults=defaults,
|
||||
)
|
||||
|
||||
changed = False
|
||||
if email and user.email != email:
|
||||
user.email = email
|
||||
changed = True
|
||||
if not user.name and name:
|
||||
user.name = name
|
||||
changed = True
|
||||
if changed:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
logger.info(
|
||||
"auth.user.sync clerk_user_id=%s updated=%s claim_email=%s final_email=%s",
|
||||
clerk_user_id,
|
||||
changed,
|
||||
_normalize_email(defaults.get("email")),
|
||||
_normalize_email(user.email),
|
||||
)
|
||||
if not user.email:
|
||||
logger.warning(
|
||||
"auth.user.sync.missing_email clerk_user_id=%s claims=%s",
|
||||
clerk_user_id,
|
||||
_claim_debug_snapshot(claims),
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def _parse_subject(claims: dict[str, object]) -> str | None:
|
||||
payload = ClerkTokenPayload.model_validate(claims)
|
||||
return payload.sub
|
||||
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> AuthContext:
|
||||
@@ -79,37 +397,18 @@ async def get_auth_context(
|
||||
if credentials is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
claims = await _decode_clerk_token(credentials.credentials)
|
||||
try:
|
||||
guard = _build_clerk_http_bearer(auto_error=False)
|
||||
clerk_credentials = await guard(request)
|
||||
except (RuntimeError, ValueError) as exc:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc
|
||||
except HTTPException as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
|
||||
|
||||
auth_data = _resolve_clerk_auth(request, clerk_credentials)
|
||||
try:
|
||||
clerk_user_id = _parse_subject(auth_data)
|
||||
clerk_user_id = _parse_subject(claims)
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc
|
||||
|
||||
if not clerk_user_id:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
claims: dict[str, object] = {}
|
||||
if auth_data and auth_data.decoded:
|
||||
claims = auth_data.decoded
|
||||
email_obj = claims.get("email")
|
||||
name_obj = claims.get("name")
|
||||
defaults: dict[str, object | None] = {
|
||||
"email": email_obj if isinstance(email_obj, str) else None,
|
||||
"name": name_obj if isinstance(name_obj, str) else None,
|
||||
}
|
||||
user, _created = await crud.get_or_create(
|
||||
user = await _get_or_sync_user(
|
||||
session,
|
||||
User,
|
||||
clerk_user_id=clerk_user_id,
|
||||
defaults=defaults,
|
||||
claims=claims,
|
||||
)
|
||||
from app.services.organizations import ensure_member_for_user
|
||||
|
||||
@@ -133,36 +432,21 @@ async def get_auth_context_optional(
|
||||
return None
|
||||
|
||||
try:
|
||||
guard = _build_clerk_http_bearer(auto_error=False)
|
||||
clerk_credentials = await guard(request)
|
||||
except (RuntimeError, ValueError) as exc:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc
|
||||
claims = await _decode_clerk_token(credentials.credentials)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
auth_data = _resolve_clerk_auth(request, clerk_credentials)
|
||||
try:
|
||||
clerk_user_id = _parse_subject(auth_data)
|
||||
clerk_user_id = _parse_subject(claims)
|
||||
except ValidationError:
|
||||
return None
|
||||
|
||||
if not clerk_user_id:
|
||||
return None
|
||||
|
||||
claims: dict[str, object] = {}
|
||||
if auth_data and auth_data.decoded:
|
||||
claims = auth_data.decoded
|
||||
email_obj = claims.get("email")
|
||||
name_obj = claims.get("name")
|
||||
defaults: dict[str, object | None] = {
|
||||
"email": email_obj if isinstance(email_obj, str) else None,
|
||||
"name": name_obj if isinstance(name_obj, str) else None,
|
||||
}
|
||||
user, _created = await crud.get_or_create(
|
||||
user = await _get_or_sync_user(
|
||||
session,
|
||||
User,
|
||||
clerk_user_id=clerk_user_id,
|
||||
defaults=defaults,
|
||||
claims=claims,
|
||||
)
|
||||
from app.services.organizations import ensure_member_for_user
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[2]
|
||||
@@ -28,7 +28,8 @@ class Settings(BaseSettings):
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
# Clerk auth (auth only; roles stored in DB)
|
||||
clerk_jwks_url: str = ""
|
||||
clerk_secret_key: str = Field(min_length=1)
|
||||
clerk_api_url: str = "https://api.clerk.com"
|
||||
clerk_verify_iat: bool = True
|
||||
clerk_leeway: float = 10.0
|
||||
|
||||
@@ -52,6 +53,8 @@ class Settings(BaseSettings):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _defaults(self) -> Self:
|
||||
if not self.clerk_secret_key.strip():
|
||||
raise ValueError("CLERK_SECRET_KEY must be set and non-empty.")
|
||||
# In dev, default to applying Alembic migrations at startup to avoid
|
||||
# schema drift (e.g. missing newly-added columns).
|
||||
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlmodel import Field
|
||||
|
||||
from app.core.time import utcnow
|
||||
@@ -18,7 +17,6 @@ class Organization(QueryModel, table=True):
|
||||
"""Top-level organization tenant record."""
|
||||
|
||||
__tablename__ = "organizations" # pyright: ignore[reportAssignmentType]
|
||||
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
name: str = Field(index=True)
|
||||
|
||||
@@ -6,7 +6,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy import or_
|
||||
from sqlmodel import col, select
|
||||
|
||||
from app.core.time import utcnow
|
||||
@@ -48,23 +48,6 @@ def is_org_admin(member: OrganizationMember) -> bool:
|
||||
return member.role in ADMIN_ROLES
|
||||
|
||||
|
||||
async def get_default_org(session: AsyncSession) -> Organization | None:
|
||||
"""Return the default personal organization if it exists."""
|
||||
return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session)
|
||||
|
||||
|
||||
async def ensure_default_org(session: AsyncSession) -> Organization:
|
||||
"""Ensure and return the default personal organization."""
|
||||
org = await get_default_org(session)
|
||||
if org is not None:
|
||||
return org
|
||||
org = Organization(name=DEFAULT_ORG_NAME, created_at=utcnow(), updated_at=utcnow())
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
|
||||
async def get_member(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -216,31 +199,41 @@ async def ensure_member_for_user(
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
# Serialize first-time provisioning per user to avoid concurrent duplicate org/member creation.
|
||||
await session.exec(
|
||||
select(User.id)
|
||||
.where(col(User.id) == user.id)
|
||||
.with_for_update(),
|
||||
)
|
||||
|
||||
existing_member = await get_first_membership(session, user.id)
|
||||
if existing_member is not None:
|
||||
if user.active_organization_id != existing_member.organization_id:
|
||||
user.active_organization_id = existing_member.organization_id
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
return existing_member
|
||||
|
||||
if user.email:
|
||||
invite = await _find_pending_invite(session, user.email)
|
||||
if invite is not None:
|
||||
return await accept_invite(session, invite, user)
|
||||
|
||||
org = await ensure_default_org(session)
|
||||
now = utcnow()
|
||||
member_count = (
|
||||
await session.exec(
|
||||
select(func.count()).where(
|
||||
col(OrganizationMember.organization_id) == org.id,
|
||||
),
|
||||
)
|
||||
).one()
|
||||
is_first = int(member_count or 0) == 0
|
||||
org = Organization(name=DEFAULT_ORG_NAME, created_at=now, updated_at=now)
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
org_id = org.id
|
||||
member = OrganizationMember(
|
||||
organization_id=org.id,
|
||||
organization_id=org_id,
|
||||
user_id=user.id,
|
||||
role="owner" if is_first else "member",
|
||||
all_boards_read=is_first,
|
||||
all_boards_write=is_first,
|
||||
role="owner",
|
||||
all_boards_read=True,
|
||||
all_boards_write=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
user.active_organization_id = org.id
|
||||
user.active_organization_id = org_id
|
||||
session.add(user)
|
||||
session.add(member)
|
||||
await session.commit()
|
||||
|
||||
Reference in New Issue
Block a user