from __future__ import annotations from dataclasses import dataclass from typing import Iterable from uuid import UUID from fastapi import HTTPException, status from sqlalchemy import delete, func, or_ from sqlalchemy.sql.elements import ColumnElement from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.time import utcnow from app.models.boards import Board from app.models.organization_board_access import OrganizationBoardAccess from app.models.organization_invite_board_access import OrganizationInviteBoardAccess from app.models.organization_invites import OrganizationInvite from app.models.organization_members import OrganizationMember from app.models.organizations import Organization from app.models.users import User from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate DEFAULT_ORG_NAME = "Personal" ADMIN_ROLES = {"owner", "admin"} ROLE_RANK = {"member": 0, "admin": 1, "owner": 2} @dataclass(frozen=True) class OrganizationContext: organization: Organization member: OrganizationMember def is_org_admin(member: OrganizationMember) -> bool: return member.role in ADMIN_ROLES async def get_default_org(session: AsyncSession) -> Organization | None: statement = select(Organization).where(col(Organization.name) == DEFAULT_ORG_NAME) return (await session.exec(statement)).first() async def ensure_default_org(session: AsyncSession) -> Organization: org = await get_default_org(session) if org is not None: return org org = Organization(name=DEFAULT_ORG_NAME, created_at=utcnow(), updated_at=utcnow()) session.add(org) await session.commit() await session.refresh(org) return org async def get_member( session: AsyncSession, *, user_id: UUID, organization_id: UUID, ) -> OrganizationMember | None: statement = select(OrganizationMember).where( col(OrganizationMember.organization_id) == organization_id, col(OrganizationMember.user_id) == user_id, ) return (await session.exec(statement)).first() async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None: statement = ( select(OrganizationMember) .where(col(OrganizationMember.user_id) == user_id) .order_by(col(OrganizationMember.created_at).asc()) ) return (await session.exec(statement)).first() async def set_active_organization( session: AsyncSession, *, user: User, organization_id: UUID, ) -> OrganizationMember: member = await get_member(session, user_id=user.id, organization_id=organization_id) if member is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access") if user.active_organization_id != organization_id: user.active_organization_id = organization_id session.add(user) await session.commit() return member async def get_active_membership( session: AsyncSession, user: User, ) -> OrganizationMember | None: db_user = await session.get(User, user.id) if db_user is None: db_user = user if db_user.active_organization_id: member = await get_member( session, user_id=db_user.id, organization_id=db_user.active_organization_id, ) if member is not None: user.active_organization_id = db_user.active_organization_id return member db_user.active_organization_id = None session.add(db_user) await session.commit() member = await get_first_membership(session, db_user.id) if member is None: return None await set_active_organization( session, user=db_user, organization_id=member.organization_id, ) user.active_organization_id = db_user.active_organization_id return member async def _find_pending_invite( session: AsyncSession, email: str, ) -> OrganizationInvite | None: statement = ( select(OrganizationInvite) .where(col(OrganizationInvite.accepted_at).is_(None)) .where(col(OrganizationInvite.invited_email) == email) .order_by(col(OrganizationInvite.created_at).asc()) ) return (await session.exec(statement)).first() async def accept_invite( session: AsyncSession, invite: OrganizationInvite, user: User, ) -> OrganizationMember: now = utcnow() member = OrganizationMember( organization_id=invite.organization_id, user_id=user.id, role=invite.role, all_boards_read=invite.all_boards_read, all_boards_write=invite.all_boards_write, created_at=now, updated_at=now, ) session.add(member) await session.flush() if not (invite.all_boards_read or invite.all_boards_write): access_rows = list( await session.exec( select(OrganizationInviteBoardAccess).where( col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id ) ) ) for row in access_rows: session.add( OrganizationBoardAccess( organization_member_id=member.id, board_id=row.board_id, can_read=row.can_read, can_write=row.can_write, created_at=now, updated_at=now, ) ) invite.accepted_by_user_id = user.id invite.accepted_at = now invite.updated_at = now session.add(invite) if user.active_organization_id is None: user.active_organization_id = invite.organization_id session.add(user) await session.commit() await session.refresh(member) return member async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember: existing = await get_active_membership(session, user) if existing is not None: return existing if user.email: invite = await _find_pending_invite(session, user.email) if invite is not None: return await accept_invite(session, invite, user) org = await ensure_default_org(session) now = utcnow() member_count = ( await session.exec( select(func.count()).where(col(OrganizationMember.organization_id) == org.id) ) ).one() is_first = int(member_count or 0) == 0 member = OrganizationMember( organization_id=org.id, user_id=user.id, role="owner" if is_first else "member", all_boards_read=is_first, all_boards_write=is_first, created_at=now, updated_at=now, ) user.active_organization_id = org.id session.add(user) session.add(member) await session.commit() await session.refresh(member) return member def member_all_boards_read(member: OrganizationMember) -> bool: return member.all_boards_read or member.all_boards_write def member_all_boards_write(member: OrganizationMember) -> bool: return member.all_boards_write async def has_board_access( session: AsyncSession, *, member: OrganizationMember, board: Board, write: bool, ) -> bool: if member.organization_id != board.organization_id: return False if write: if member_all_boards_write(member): return True else: if member_all_boards_read(member): return True statement = select(OrganizationBoardAccess).where( col(OrganizationBoardAccess.organization_member_id) == member.id, col(OrganizationBoardAccess.board_id) == board.id, ) access = (await session.exec(statement)).first() if access is None: return False if write: return bool(access.can_write) return bool(access.can_read or access.can_write) async def require_board_access( session: AsyncSession, *, user: User, board: Board, write: bool, ) -> OrganizationMember: member = await get_member(session, user_id=user.id, organization_id=board.organization_id) if member is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access") if not await has_board_access(session, member=member, board=board, write=write): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied") return member def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]: if write and member_all_boards_write(member): return col(Board.organization_id) == member.organization_id if not write and member_all_boards_read(member): return col(Board.organization_id) == member.organization_id access_stmt = select(OrganizationBoardAccess.board_id).where( col(OrganizationBoardAccess.organization_member_id) == member.id ) if write: access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True)) else: access_stmt = access_stmt.where( or_( col(OrganizationBoardAccess.can_read).is_(True), col(OrganizationBoardAccess.can_write).is_(True), ) ) return col(Board.id).in_(access_stmt) async def list_accessible_board_ids( session: AsyncSession, *, member: OrganizationMember, write: bool, ) -> list[UUID]: if (write and member_all_boards_write(member)) or ( not write and member_all_boards_read(member) ): ids = await session.exec( select(Board.id).where(col(Board.organization_id) == member.organization_id) ) return list(ids) access_stmt = select(OrganizationBoardAccess.board_id).where( col(OrganizationBoardAccess.organization_member_id) == member.id ) if write: access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True)) else: access_stmt = access_stmt.where( or_( col(OrganizationBoardAccess.can_read).is_(True), col(OrganizationBoardAccess.can_write).is_(True), ) ) board_ids = await session.exec(access_stmt) return list(board_ids) async def apply_member_access_update( session: AsyncSession, *, member: OrganizationMember, update: OrganizationMemberAccessUpdate, ) -> None: now = utcnow() member.all_boards_read = update.all_boards_read member.all_boards_write = update.all_boards_write member.updated_at = now session.add(member) await session.execute( delete(OrganizationBoardAccess).where( col(OrganizationBoardAccess.organization_member_id) == member.id ), ) if update.all_boards_read or update.all_boards_write: return rows: list[OrganizationBoardAccess] = [] for entry in update.board_access: rows.append( OrganizationBoardAccess( organization_member_id=member.id, board_id=entry.board_id, can_read=entry.can_read, can_write=entry.can_write, created_at=now, updated_at=now, ) ) session.add_all(rows) async def apply_invite_board_access( session: AsyncSession, *, invite: OrganizationInvite, entries: Iterable[OrganizationBoardAccessSpec], ) -> None: await session.execute( delete(OrganizationInviteBoardAccess).where( col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id ), ) if invite.all_boards_read or invite.all_boards_write: return now = utcnow() rows: list[OrganizationInviteBoardAccess] = [] for entry in entries: rows.append( OrganizationInviteBoardAccess( organization_invite_id=invite.id, board_id=entry.board_id, can_read=entry.can_read, can_write=entry.can_write, created_at=now, updated_at=now, ) ) session.add_all(rows) def normalize_invited_email(email: str) -> str: return email.strip().lower() def normalize_role(role: str) -> str: return role.strip().lower() or "member" def _role_rank(role: str | None) -> int: if not role: return 0 return ROLE_RANK.get(role, 0) async def apply_invite_to_member( session: AsyncSession, *, member: OrganizationMember, invite: OrganizationInvite, ) -> None: now = utcnow() member_changed = False invite_role = normalize_role(invite.role or "member") if _role_rank(invite_role) > _role_rank(member.role): member.role = invite_role member_changed = True if invite.all_boards_read or invite.all_boards_write: member.all_boards_read = ( member.all_boards_read or invite.all_boards_read or invite.all_boards_write ) member.all_boards_write = member.all_boards_write or invite.all_boards_write member_changed = True if member_changed: member.updated_at = now session.add(member) return access_rows = list( await session.exec( select(OrganizationInviteBoardAccess).where( col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id ) ) ) for row in access_rows: existing = ( await session.exec( select(OrganizationBoardAccess).where( col(OrganizationBoardAccess.organization_member_id) == member.id, col(OrganizationBoardAccess.board_id) == row.board_id, ) ) ).first() can_write = bool(row.can_write) can_read = bool(row.can_read or row.can_write) if existing is None: session.add( OrganizationBoardAccess( organization_member_id=member.id, board_id=row.board_id, can_read=can_read, can_write=can_write, created_at=now, updated_at=now, ) ) else: existing.can_read = bool(existing.can_read or can_read) existing.can_write = bool(existing.can_write or can_write) existing.updated_at = now session.add(existing) if member_changed: member.updated_at = now session.add(member)