refactor: update module docstrings for clarity and consistency
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
"""Organization membership and board-access service helpers."""
|
||||
# ruff: noqa: D101, D103
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.time import utcnow
|
||||
from app.db import crud
|
||||
@@ -19,7 +19,17 @@ from app.models.organization_invites import OrganizationInvite
|
||||
from app.models.organization_members import OrganizationMember
|
||||
from app.models.organizations import Organization
|
||||
from app.models.users import User
|
||||
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.schemas.organizations import (
|
||||
OrganizationBoardAccessSpec,
|
||||
OrganizationMemberAccessUpdate,
|
||||
)
|
||||
|
||||
DEFAULT_ORG_NAME = "Personal"
|
||||
ADMIN_ROLES = {"owner", "admin"}
|
||||
@@ -63,7 +73,9 @@ async def get_member(
|
||||
).first(session)
|
||||
|
||||
|
||||
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
|
||||
async def get_first_membership(
|
||||
session: AsyncSession, user_id: UUID,
|
||||
) -> OrganizationMember | None:
|
||||
return (
|
||||
await OrganizationMember.objects.filter_by(user_id=user_id)
|
||||
.order_by(col(OrganizationMember.created_at).asc())
|
||||
@@ -79,7 +91,9 @@ async def set_active_organization(
|
||||
) -> OrganizationMember:
|
||||
member = await get_member(session, user_id=user.id, organization_id=organization_id)
|
||||
if member is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
||||
)
|
||||
if user.active_organization_id != organization_id:
|
||||
user.active_organization_id = organization_id
|
||||
session.add(user)
|
||||
@@ -154,9 +168,10 @@ async def accept_invite(
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationInviteBoardAccess).where(
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
||||
)
|
||||
)
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id)
|
||||
== invite.id,
|
||||
),
|
||||
),
|
||||
)
|
||||
for row in access_rows:
|
||||
session.add(
|
||||
@@ -167,7 +182,7 @@ async def accept_invite(
|
||||
can_write=row.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
invite.accepted_by_user_id = user.id
|
||||
@@ -182,7 +197,9 @@ async def accept_invite(
|
||||
return member
|
||||
|
||||
|
||||
async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember:
|
||||
async def ensure_member_for_user(
|
||||
session: AsyncSession, user: User,
|
||||
) -> OrganizationMember:
|
||||
existing = await get_active_membership(session, user)
|
||||
if existing is not None:
|
||||
return existing
|
||||
@@ -196,7 +213,9 @@ async def ensure_member_for_user(session: AsyncSession, user: User) -> Organizat
|
||||
now = utcnow()
|
||||
member_count = (
|
||||
await session.exec(
|
||||
select(func.count()).where(col(OrganizationMember.organization_id) == org.id)
|
||||
select(func.count()).where(
|
||||
col(OrganizationMember.organization_id) == org.id,
|
||||
),
|
||||
)
|
||||
).one()
|
||||
is_first = int(member_count or 0) == 0
|
||||
@@ -257,30 +276,40 @@ async def require_board_access(
|
||||
board: Board,
|
||||
write: bool,
|
||||
) -> OrganizationMember:
|
||||
member = await get_member(session, user_id=user.id, organization_id=board.organization_id)
|
||||
member = await get_member(
|
||||
session, user_id=user.id, organization_id=board.organization_id,
|
||||
)
|
||||
if member is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
||||
)
|
||||
if not await has_board_access(session, member=member, board=board, write=write):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied",
|
||||
)
|
||||
return member
|
||||
|
||||
|
||||
def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]:
|
||||
def board_access_filter(
|
||||
member: OrganizationMember, *, write: bool,
|
||||
) -> ColumnElement[bool]:
|
||||
if write and member_all_boards_write(member):
|
||||
return col(Board.organization_id) == member.organization_id
|
||||
if not write and member_all_boards_read(member):
|
||||
return col(Board.organization_id) == member.organization_id
|
||||
access_stmt = select(OrganizationBoardAccess.board_id).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
)
|
||||
if write:
|
||||
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
|
||||
access_stmt = access_stmt.where(
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
else:
|
||||
access_stmt = access_stmt.where(
|
||||
or_(
|
||||
col(OrganizationBoardAccess.can_read).is_(True),
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
),
|
||||
)
|
||||
return col(Board.id).in_(access_stmt)
|
||||
|
||||
@@ -295,21 +324,25 @@ async def list_accessible_board_ids(
|
||||
not write and member_all_boards_read(member)
|
||||
):
|
||||
ids = await session.exec(
|
||||
select(Board.id).where(col(Board.organization_id) == member.organization_id)
|
||||
select(Board.id).where(
|
||||
col(Board.organization_id) == member.organization_id,
|
||||
),
|
||||
)
|
||||
return list(ids)
|
||||
|
||||
access_stmt = select(OrganizationBoardAccess.board_id).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
)
|
||||
if write:
|
||||
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
|
||||
access_stmt = access_stmt.where(
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
else:
|
||||
access_stmt = access_stmt.where(
|
||||
or_(
|
||||
col(OrganizationBoardAccess.can_read).is_(True),
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
),
|
||||
)
|
||||
board_ids = await session.exec(access_stmt)
|
||||
return list(board_ids)
|
||||
@@ -337,18 +370,17 @@ async def apply_member_access_update(
|
||||
if update.all_boards_read or update.all_boards_write:
|
||||
return
|
||||
|
||||
rows: list[OrganizationBoardAccess] = []
|
||||
for entry in update.board_access:
|
||||
rows.append(
|
||||
OrganizationBoardAccess(
|
||||
organization_member_id=member.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
rows = [
|
||||
OrganizationBoardAccess(
|
||||
organization_member_id=member.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
for entry in update.board_access
|
||||
]
|
||||
session.add_all(rows)
|
||||
|
||||
|
||||
@@ -367,18 +399,17 @@ async def apply_invite_board_access(
|
||||
if invite.all_boards_read or invite.all_boards_write:
|
||||
return
|
||||
now = utcnow()
|
||||
rows: list[OrganizationInviteBoardAccess] = []
|
||||
for entry in entries:
|
||||
rows.append(
|
||||
OrganizationInviteBoardAccess(
|
||||
organization_invite_id=invite.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
rows = [
|
||||
OrganizationInviteBoardAccess(
|
||||
organization_invite_id=invite.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
for entry in entries
|
||||
]
|
||||
session.add_all(rows)
|
||||
|
||||
|
||||
@@ -423,9 +454,9 @@ async def apply_invite_to_member(
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationInviteBoardAccess).where(
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
||||
)
|
||||
)
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
|
||||
),
|
||||
),
|
||||
)
|
||||
for row in access_rows:
|
||||
existing = (
|
||||
@@ -433,7 +464,7 @@ async def apply_invite_to_member(
|
||||
select(OrganizationBoardAccess).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
col(OrganizationBoardAccess.board_id) == row.board_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
).first()
|
||||
can_write = bool(row.can_write)
|
||||
@@ -447,7 +478,7 @@ async def apply_invite_to_member(
|
||||
can_write=can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
existing.can_read = bool(existing.can_read or can_read)
|
||||
|
||||
Reference in New Issue
Block a user