feat: refactor organization context usage in board and gateway endpoints

This commit is contained in:
Abhimanyu Saharan
2026-02-08 21:37:20 +05:30
parent 3f556802a9
commit 061563964d
7 changed files with 37 additions and 32 deletions

View File

@@ -18,6 +18,7 @@ from app.models.agents import Agent
from app.models.board_groups import BoardGroup
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.organization_members import OrganizationMember
from app.schemas.board_group_heartbeat import (
BoardGroupHeartbeatApply,
BoardGroupHeartbeatApplyResult,
@@ -29,6 +30,7 @@ from app.schemas.view_models import BoardGroupSnapshot
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, sync_gateway_agent_heartbeats
from app.services.board_group_snapshot import build_group_snapshot
from app.services.organizations import (
OrganizationContext,
board_access_filter,
get_member,
is_org_admin,
@@ -49,7 +51,7 @@ async def _require_group_access(
session: AsyncSession,
*,
group_id: UUID,
member,
member: OrganizationMember,
write: bool,
) -> BoardGroup:
group = await session.get(BoardGroup, group_id)
@@ -80,7 +82,7 @@ async def _require_group_access(
@router.get("", response_model=DefaultLimitOffsetPage[BoardGroupRead])
async def list_board_groups(
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member),
ctx: OrganizationContext = Depends(require_org_member),
) -> DefaultLimitOffsetPage[BoardGroupRead]:
if member_all_boards_read(ctx.member):
statement = select(BoardGroup).where(col(BoardGroup.organization_id) == ctx.organization.id)
@@ -100,7 +102,7 @@ async def list_board_groups(
async def create_board_group(
payload: BoardGroupCreate,
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> BoardGroup:
data = payload.model_dump()
if not (data.get("slug") or "").strip():
@@ -113,7 +115,7 @@ async def create_board_group(
async def get_board_group(
group_id: UUID,
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member),
ctx: OrganizationContext = Depends(require_org_member),
) -> BoardGroup:
return await _require_group_access(session, group_id=group_id, member=ctx.member, write=False)
@@ -124,7 +126,7 @@ async def get_board_group_snapshot(
include_done: bool = False,
per_board_task_limit: int = 5,
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member),
ctx: OrganizationContext = Depends(require_org_member),
) -> BoardGroupSnapshot:
group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=False)
if per_board_task_limit < 0:
@@ -253,7 +255,7 @@ async def update_board_group(
payload: BoardGroupUpdate,
group_id: UUID,
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> BoardGroup:
group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)
updates = payload.model_dump(exclude_unset=True)
@@ -269,7 +271,7 @@ async def update_board_group(
async def delete_board_group(
group_id: UUID,
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse:
await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)

View File

@@ -43,7 +43,7 @@ from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.view_models import BoardGroupSnapshot, BoardSnapshot
from app.services.board_group_snapshot import build_board_group_snapshot
from app.services.board_snapshot import build_board_snapshot
from app.services.organizations import board_access_filter
from app.services.organizations import OrganizationContext, board_access_filter
router = APIRouter(prefix="/boards", tags=["boards"])
@@ -81,7 +81,7 @@ async def _require_gateway(
async def _require_gateway_for_create(
payload: BoardCreate,
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = Depends(get_session),
) -> Gateway:
return await _require_gateway(session, payload.gateway_id, organization_id=ctx.organization.id)
@@ -109,7 +109,7 @@ async def _require_board_group(
async def _require_board_group_for_create(
payload: BoardCreate,
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = Depends(get_session),
) -> BoardGroup | None:
if payload.board_group_id is None:
@@ -220,7 +220,7 @@ async def list_boards(
gateway_id: UUID | None = Query(default=None),
board_group_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member),
ctx: OrganizationContext = Depends(require_org_member),
) -> DefaultLimitOffsetPage[BoardRead]:
statement = select(Board).where(board_access_filter(ctx.member, write=False))
if gateway_id is not None:
@@ -237,7 +237,7 @@ async def create_board(
_gateway: Gateway = Depends(_require_gateway_for_create),
_board_group: BoardGroup | None = Depends(_require_board_group_for_create),
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> Board:
data = payload.model_dump()
data["organization_id"] = ctx.organization.id

View File

@@ -25,6 +25,7 @@ from app.schemas.gateways import (
)
from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent
from app.services.organizations import OrganizationContext
from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service
router = APIRouter(prefix="/gateways", tags=["gateways"])
@@ -131,7 +132,7 @@ async def _ensure_main_agent(
@router.get("", response_model=DefaultLimitOffsetPage[GatewayRead])
async def list_gateways(
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[GatewayRead]:
statement = (
select(Gateway)
@@ -146,7 +147,7 @@ async def create_gateway(
payload: GatewayCreate,
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway:
data = payload.model_dump()
data["organization_id"] = ctx.organization.id
@@ -162,7 +163,7 @@ async def create_gateway(
async def get_gateway(
gateway_id: UUID,
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:
@@ -176,7 +177,7 @@ async def update_gateway(
payload: GatewayUpdate,
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:
@@ -210,7 +211,7 @@ async def sync_gateway_templates(
board_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> GatewayTemplatesSyncResult:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:
@@ -231,7 +232,7 @@ async def sync_gateway_templates(
async def delete_gateway(
gateway_id: UUID,
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:

View File

@@ -26,7 +26,7 @@ from app.schemas.metrics import (
DashboardWipRangeSeries,
DashboardWipSeriesSet,
)
from app.services.organizations import list_accessible_board_ids
from app.services.organizations import OrganizationContext, list_accessible_board_ids
router = APIRouter(prefix="/metrics", tags=["metrics"])
@@ -304,7 +304,7 @@ async def _tasks_in_progress(session: AsyncSession, board_ids: list[UUID]) -> in
async def dashboard_metrics(
range: Literal["24h", "7d"] = Query(default="24h"),
session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member),
ctx: OrganizationContext = Depends(require_org_member),
) -> DashboardMetrics:
primary = _resolve_range(range)
comparison = _comparison_range(range)

View File

@@ -5,7 +5,7 @@ from typing import Any, Sequence
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func
from sqlalchemy import delete, func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -168,7 +168,7 @@ async def get_my_membership(
)
model = _member_to_read(ctx.member, user)
model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) # type: ignore[name-defined]
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows
]
return model
@@ -216,7 +216,7 @@ async def get_org_member(
)
model = _member_to_read(member, user)
model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) # type: ignore[name-defined]
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows
]
return model
@@ -351,9 +351,9 @@ async def revoke_org_invite(
if invite is None or invite.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await session.execute(
OrganizationInviteBoardAccess.__table__.delete().where(
delete(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
)
),
)
await session.delete(invite)
await session.commit()