461 lines
14 KiB
Python
461 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Iterable
|
|
from uuid import UUID
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import func, or_
|
|
from sqlmodel import col, select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.time import utcnow
|
|
from app.models.boards import Board
|
|
from app.models.organization_board_access import OrganizationBoardAccess
|
|
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
|
|
from app.models.organization_invites import OrganizationInvite
|
|
from app.models.organization_members import OrganizationMember
|
|
from app.models.organizations import Organization
|
|
from app.models.users import User
|
|
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
|
|
|
|
# Name of the organization that get_default_org / ensure_default_org look up or create.
DEFAULT_ORG_NAME = "Personal"

# Roles that is_org_admin treats as administrative.
ADMIN_ROLES = {"admin", "owner"}

# Relative ordering of roles; a larger number outranks a smaller one.
ROLE_RANK = {
    "member": 0,
    "admin": 1,
    "owner": 2,
}
|
|
|
|
|
|
@dataclass(frozen=True)
class OrganizationContext:
    """Immutable bundle of an organization and a membership record."""

    # The organization this context refers to.
    organization: Organization
    # The membership row carried alongside the organization.
    member: OrganizationMember
|
|
|
|
|
|
def is_org_admin(member: OrganizationMember) -> bool:
    """Return whether the member's role is one of ``ADMIN_ROLES``."""
    role = member.role
    return role in ADMIN_ROLES
|
|
|
|
|
|
async def get_default_org(session: AsyncSession) -> Organization | None:
    """Look up the organization whose name equals ``DEFAULT_ORG_NAME``.

    Returns the first matching row, or ``None`` when there is no match.
    """
    name_matches = col(Organization.name) == DEFAULT_ORG_NAME
    result = await session.exec(select(Organization).where(name_matches))
    return result.first()
|
|
|
|
|
|
async def ensure_default_org(session: AsyncSession) -> Organization:
    """Return the default organization, creating it when it does not exist.

    An existing organization found by ``get_default_org`` is returned as-is.
    Otherwise a new ``Organization`` named ``DEFAULT_ORG_NAME`` is added,
    committed, refreshed and returned.
    """
    existing = await get_default_org(session)
    if existing is not None:
        return existing

    # Take the timestamp once so a freshly created row has
    # created_at == updated_at instead of two slightly different instants.
    now = utcnow()
    org = Organization(name=DEFAULT_ORG_NAME, created_at=now, updated_at=now)
    session.add(org)
    await session.commit()
    await session.refresh(org)
    return org
|
|
|
|
|
|
async def get_member(
    session: AsyncSession,
    *,
    user_id: UUID,
    organization_id: UUID,
) -> OrganizationMember | None:
    """Fetch the membership row linking ``user_id`` to ``organization_id``.

    Returns the first matching ``OrganizationMember``, or ``None`` when the
    query yields no row.
    """
    same_org = col(OrganizationMember.organization_id) == organization_id
    same_user = col(OrganizationMember.user_id) == user_id
    result = await session.exec(select(OrganizationMember).where(same_org, same_user))
    return result.first()
|
|
|
|
|
|
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
    """Return the user's membership with the earliest ``created_at``, if any."""
    query = select(OrganizationMember)
    query = query.where(col(OrganizationMember.user_id) == user_id)
    query = query.order_by(col(OrganizationMember.created_at).asc())
    rows = await session.exec(query)
    return rows.first()
|
|
|
|
|
|
async def set_active_organization(
    session: AsyncSession,
    *,
    user: User,
    organization_id: UUID,
) -> OrganizationMember:
    """Make ``organization_id`` the user's active organization.

    The user must already have a membership in that organization; otherwise
    an ``HTTPException`` with status 403 and detail ``"No org access"`` is
    raised. The user row is only written and committed when the active
    organization actually changes.

    Returns the membership that authorizes the switch.
    """
    membership = await get_member(
        session,
        user_id=user.id,
        organization_id=organization_id,
    )
    if membership is None:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")

    needs_update = user.active_organization_id != organization_id
    if needs_update:
        user.active_organization_id = organization_id
        session.add(user)
        await session.commit()
    return membership
|
|
|
|
|
|
async def get_active_membership(
    session: AsyncSession,
    user: User,
) -> OrganizationMember | None:
    """Resolve the membership behind the user's active organization.

    Steps:
      1. Load the user through ``session.get``; if that yields nothing, work
         with the ``user`` object that was passed in.
      2. If an active organization is recorded and the user still has a
         membership in it, return that membership.
      3. If the recorded organization has no matching membership, clear the
         pointer and commit.
      4. Fall back to the user's earliest membership (``get_first_membership``),
         make it active via ``set_active_organization`` and return it.

    ``user.active_organization_id`` is kept in step with the loaded user row
    on every exit path, including the one where no membership exists.

    Returns ``None`` when the user has no membership at all.
    """
    db_user = await session.get(User, user.id)
    if db_user is None:
        db_user = user

    if db_user.active_organization_id:
        member = await get_member(
            session,
            user_id=db_user.id,
            organization_id=db_user.active_organization_id,
        )
        if member is not None:
            user.active_organization_id = db_user.active_organization_id
            return member

        # The recorded organization no longer has a membership for this user:
        # drop the stale pointer and persist that.
        db_user.active_organization_id = None
        if user is not db_user:
            # Fix: also clear the caller's copy. Previously it kept the stale
            # id whenever no fallback membership was found below.
            user.active_organization_id = None
        session.add(db_user)
        await session.commit()

    member = await get_first_membership(session, db_user.id)
    if member is None:
        return None

    await set_active_organization(
        session,
        user=db_user,
        organization_id=member.organization_id,
    )
    user.active_organization_id = db_user.active_organization_id
    return member
|
|
|
|
|
|
async def _find_pending_invite(
    session: AsyncSession,
    email: str,
) -> OrganizationInvite | None:
    """Return the oldest not-yet-accepted invite addressed to ``email``.

    "Not yet accepted" means ``accepted_at`` is NULL; ordering is by ascending
    ``created_at``. Returns ``None`` when no invite matches.
    """
    not_accepted = col(OrganizationInvite.accepted_at).is_(None)
    same_email = col(OrganizationInvite.invited_email) == email
    query = (
        select(OrganizationInvite)
        .where(not_accepted)
        .where(same_email)
        .order_by(col(OrganizationInvite.created_at).asc())
    )
    rows = await session.exec(query)
    return rows.first()
|
|
|
|
|
|
async def accept_invite(
    session: AsyncSession,
    invite: OrganizationInvite,
    user: User,
) -> OrganizationMember:
    """Turn an invite into a membership for ``user`` and mark it accepted.

    The new ``OrganizationMember`` copies the invite's role and all-boards
    flags. When the invite grants neither all-boards flag, each
    ``OrganizationInviteBoardAccess`` row of the invite is copied into an
    ``OrganizationBoardAccess`` row for the new member. The invite is stamped
    with the accepting user and time, and the organization becomes the user's
    active one if the user has none. Everything is committed in one
    transaction.

    Returns the refreshed membership.
    """
    timestamp = utcnow()
    membership = OrganizationMember(
        organization_id=invite.organization_id,
        user_id=user.id,
        role=invite.role,
        all_boards_read=invite.all_boards_read,
        all_boards_write=invite.all_boards_write,
        created_at=timestamp,
        updated_at=timestamp,
    )
    session.add(membership)
    # Flushed before the per-board rows below reference membership.id.
    await session.flush()

    grants_every_board = invite.all_boards_read or invite.all_boards_write
    if not grants_every_board:
        invite_grants = await session.exec(
            select(OrganizationInviteBoardAccess).where(
                col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
            )
        )
        session.add_all(
            [
                OrganizationBoardAccess(
                    organization_member_id=membership.id,
                    board_id=grant.board_id,
                    can_read=grant.can_read,
                    can_write=grant.can_write,
                    created_at=timestamp,
                    updated_at=timestamp,
                )
                for grant in list(invite_grants)
            ]
        )

    # Record who accepted the invite and when.
    invite.accepted_at = timestamp
    invite.accepted_by_user_id = user.id
    invite.updated_at = timestamp
    session.add(invite)

    # Only set the active organization when the user does not have one yet.
    if user.active_organization_id is None:
        user.active_organization_id = invite.organization_id
        session.add(user)

    await session.commit()
    await session.refresh(membership)
    return membership
|
|
|
|
|
|
async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember:
    """Guarantee that ``user`` has a membership and return it.

    Resolution order:
      1. The membership found by ``get_active_membership``.
      2. The oldest pending invite for the user's e-mail, accepted on the spot.
      3. A new membership in the default organization. The first member of
         that organization becomes ``"owner"`` with all-boards read and write;
         everyone after that becomes ``"member"`` with neither flag.

    In case 3 the default organization also becomes the user's active one and
    the change is committed.
    """
    current = await get_active_membership(session, user)
    if current is not None:
        return current

    # NOTE(review): the address is passed to the lookup as-is; confirm invites
    # are stored in the same form (see normalize_invited_email).
    if user.email:
        pending = await _find_pending_invite(session, user.email)
        if pending is not None:
            return await accept_invite(session, pending, user)

    org = await ensure_default_org(session)
    now = utcnow()

    count_query = select(func.count()).where(col(OrganizationMember.organization_id) == org.id)
    member_count = (await session.exec(count_query)).one()
    is_first = int(member_count or 0) == 0

    if is_first:
        role = "owner"
    else:
        role = "member"

    member = OrganizationMember(
        organization_id=org.id,
        user_id=user.id,
        role=role,
        all_boards_read=is_first,
        all_boards_write=is_first,
        created_at=now,
        updated_at=now,
    )
    user.active_organization_id = org.id
    session.add(user)
    session.add(member)
    await session.commit()
    await session.refresh(member)
    return member
|
|
|
|
|
|
def member_all_boards_read(member: OrganizationMember) -> bool:
    """Return whether the member may read every board.

    The all-boards write flag counts as read permission as well.
    """
    reads_everything = member.all_boards_read
    return reads_everything or member.all_boards_write
|
|
|
|
|
|
def member_all_boards_write(member: OrganizationMember) -> bool:
    """Return the member's all-boards write flag."""
    writes_everything = member.all_boards_write
    return writes_everything
|
|
|
|
|
|
async def has_board_access(
    session: AsyncSession,
    *,
    member: OrganizationMember,
    board: Board,
    write: bool,
) -> bool:
    """Decide whether ``member`` may read (or, with ``write``, modify) ``board``.

    Access is denied outright when the board belongs to a different
    organization than the member. Otherwise the member's all-boards flags are
    consulted first, then the per-board ``OrganizationBoardAccess`` row.
    For read checks, a row with ``can_write`` set also counts as readable.
    """
    if member.organization_id != board.organization_id:
        return False

    if write:
        has_blanket_access = member_all_boards_write(member)
    else:
        has_blanket_access = member_all_boards_read(member)
    if has_blanket_access:
        return True

    query = select(OrganizationBoardAccess).where(
        col(OrganizationBoardAccess.organization_member_id) == member.id,
        col(OrganizationBoardAccess.board_id) == board.id,
    )
    grant = (await session.exec(query)).first()
    if grant is None:
        return False
    return bool(grant.can_write) if write else bool(grant.can_read or grant.can_write)
|
|
|
|
|
|
async def require_board_access(
    session: AsyncSession,
    *,
    user: User,
    board: Board,
    write: bool,
) -> OrganizationMember:
    """Return the user's membership for ``board`` or raise a 403.

    Raises:
        HTTPException: status 403 with detail ``"No org access"`` when the
            user has no membership in the board's organization, or detail
            ``"Board access denied"`` when ``has_board_access`` rejects the
            requested read/write access.
    """
    membership = await get_member(
        session,
        user_id=user.id,
        organization_id=board.organization_id,
    )
    if membership is None:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")

    allowed = await has_board_access(session, member=membership, board=board, write=write)
    if not allowed:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied")
    return membership
|
|
|
|
|
|
def board_access_filter(member: OrganizationMember, *, write: bool) -> object:
    """Build a SQL condition selecting the boards ``member`` can access.

    With the relevant all-boards flag set, the condition matches every board
    in the member's organization. Otherwise it matches boards whose id appears
    in the member's ``OrganizationBoardAccess`` rows with the needed
    permission (``can_write`` for write; ``can_read`` or ``can_write`` for
    read).
    """
    if write:
        has_blanket_access = member_all_boards_write(member)
    else:
        has_blanket_access = member_all_boards_read(member)
    if has_blanket_access:
        return col(Board.organization_id) == member.organization_id

    if write:
        permission = col(OrganizationBoardAccess.can_write).is_(True)
    else:
        permission = or_(
            col(OrganizationBoardAccess.can_read).is_(True),
            col(OrganizationBoardAccess.can_write).is_(True),
        )

    granted_board_ids = (
        select(OrganizationBoardAccess.board_id)
        .where(col(OrganizationBoardAccess.organization_member_id) == member.id)
        .where(permission)
    )
    return col(Board.id).in_(granted_board_ids)
|
|
|
|
|
|
async def list_accessible_board_ids(
    session: AsyncSession,
    *,
    member: OrganizationMember,
    write: bool,
) -> list[UUID]:
    """List the ids of the boards ``member`` can read (or write).

    With the relevant all-boards flag set, every board id in the member's
    organization is returned. Otherwise the ids come from the member's
    ``OrganizationBoardAccess`` rows that carry the needed permission
    (``can_write`` for write; ``can_read`` or ``can_write`` for read).
    """
    if write:
        has_blanket_access = member_all_boards_write(member)
    else:
        has_blanket_access = member_all_boards_read(member)

    if has_blanket_access:
        query = select(Board.id).where(col(Board.organization_id) == member.organization_id)
    else:
        if write:
            permission = col(OrganizationBoardAccess.can_write).is_(True)
        else:
            permission = or_(
                col(OrganizationBoardAccess.can_read).is_(True),
                col(OrganizationBoardAccess.can_write).is_(True),
            )
        query = (
            select(OrganizationBoardAccess.board_id)
            .where(col(OrganizationBoardAccess.organization_member_id) == member.id)
            .where(permission)
        )

    found = await session.exec(query)
    return list(found)
|
|
|
|
|
|
async def apply_member_access_update(
    session: AsyncSession,
    *,
    member: OrganizationMember,
    update: OrganizationMemberAccessUpdate,
) -> None:
    """Replace a member's board permissions with those in ``update``.

    The all-boards flags are overwritten and every existing
    ``OrganizationBoardAccess`` row of the member is deleted. Per-board rows
    from ``update.board_access`` are recreated only when neither all-boards
    flag is set. No commit is issued here.
    """
    timestamp = utcnow()
    member.all_boards_read = update.all_boards_read
    member.all_boards_write = update.all_boards_write
    member.updated_at = timestamp
    session.add(member)

    # Start from a clean slate: remove all of the member's per-board rows.
    wipe_existing = OrganizationBoardAccess.__table__.delete().where(
        col(OrganizationBoardAccess.organization_member_id) == member.id
    )
    await session.execute(wipe_existing)

    if update.all_boards_read or update.all_boards_write:
        return

    session.add_all(
        [
            OrganizationBoardAccess(
                organization_member_id=member.id,
                board_id=spec.board_id,
                can_read=spec.can_read,
                can_write=spec.can_write,
                created_at=timestamp,
                updated_at=timestamp,
            )
            for spec in update.board_access
        ]
    )
|
|
|
|
|
|
async def apply_invite_board_access(
    session: AsyncSession,
    *,
    invite: OrganizationInvite,
    entries: Iterable[OrganizationBoardAccessSpec],
) -> None:
    """Replace an invite's per-board access rows with ``entries``.

    All existing ``OrganizationInviteBoardAccess`` rows of the invite are
    deleted first. New rows are added only when the invite sets neither
    all-boards flag. No commit is issued here.
    """
    wipe_existing = OrganizationInviteBoardAccess.__table__.delete().where(
        col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
    )
    await session.execute(wipe_existing)

    if invite.all_boards_read or invite.all_boards_write:
        return

    timestamp = utcnow()
    session.add_all(
        [
            OrganizationInviteBoardAccess(
                organization_invite_id=invite.id,
                board_id=spec.board_id,
                can_read=spec.can_read,
                can_write=spec.can_write,
                created_at=timestamp,
                updated_at=timestamp,
            )
            for spec in entries
        ]
    )
|
|
|
|
|
|
def normalize_invited_email(email: str) -> str:
    """Strip surrounding whitespace from ``email`` and lowercase it."""
    trimmed = email.strip()
    return trimmed.lower()
|
|
|
|
|
|
def normalize_role(role: str) -> str:
    """Strip and lowercase ``role``; an empty result becomes ``"member"``."""
    cleaned = role.strip().lower()
    if cleaned:
        return cleaned
    return "member"
|
|
|
|
|
|
def _role_rank(role: str | None) -> int:
    """Return the ``ROLE_RANK`` value for ``role``.

    A missing, empty or unknown role ranks as 0.
    """
    return ROLE_RANK.get(role, 0) if role else 0
|
|
|
|
|
|
async def apply_invite_to_member(
    session: AsyncSession,
    *,
    member: OrganizationMember,
    invite: OrganizationInvite,
) -> None:
    """Merge an invite's role and board grants into an existing membership.

    The merge only ever adds privileges:
      * the member's role is replaced only when the invite's normalized role
        ranks strictly higher (per ``_role_rank``);
      * all-boards flags are OR-ed into the member's flags (the invite's write
        flag also implies read);
      * otherwise each per-board grant of the invite is OR-ed into the
        member's matching ``OrganizationBoardAccess`` row, creating the row if
        it does not exist.

    When the invite carries an all-boards flag the per-board rows are not
    touched. Changes are added to the session; no commit is issued here.
    """
    now = utcnow()
    member_changed = False

    invite_role = normalize_role(invite.role or "member")
    if _role_rank(invite_role) > _role_rank(member.role):
        member.role = invite_role
        member_changed = True

    if invite.all_boards_read or invite.all_boards_write:
        member.all_boards_read = (
            member.all_boards_read or invite.all_boards_read or invite.all_boards_write
        )
        member.all_boards_write = member.all_boards_write or invite.all_boards_write
        # This branch always modifies the member, so stamp it unconditionally.
        member.updated_at = now
        session.add(member)
        return

    invite_rows = list(
        await session.exec(
            select(OrganizationInviteBoardAccess).where(
                col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
            )
        )
    )

    # Load the member's existing grants for all affected boards in one query
    # rather than issuing one SELECT per invite row.
    existing_by_board: dict = {}
    if invite_rows:
        board_ids = list({row.board_id for row in invite_rows})
        existing_rows = await session.exec(
            select(OrganizationBoardAccess).where(
                col(OrganizationBoardAccess.organization_member_id) == member.id,
                col(OrganizationBoardAccess.board_id).in_(board_ids),
            )
        )
        for grant in existing_rows:
            # Keep the first row seen per board, like the former `.first()` lookup.
            existing_by_board.setdefault(grant.board_id, grant)

    for row in invite_rows:
        can_write = bool(row.can_write)
        can_read = bool(row.can_read or row.can_write)
        existing = existing_by_board.get(row.board_id)
        if existing is None:
            created = OrganizationBoardAccess(
                organization_member_id=member.id,
                board_id=row.board_id,
                can_read=can_read,
                can_write=can_write,
                created_at=now,
                updated_at=now,
            )
            session.add(created)
            # Register the new row so a later invite row for the same board
            # merges into it instead of creating a duplicate.
            existing_by_board[row.board_id] = created
        else:
            existing.can_read = bool(existing.can_read or can_read)
            existing.can_write = bool(existing.can_write or can_write)
            existing.updated_at = now
            session.add(existing)

    if member_changed:
        member.updated_at = now
        session.add(member)
|