feat: update context dependency to OrganizationContext in activity and agents modules

This commit is contained in:
Abhimanyu Saharan
2026-02-08 21:39:02 +05:30
parent 061563964d
commit 7addc32ff9
2 changed files with 71 additions and 61 deletions

View File

@@ -24,7 +24,11 @@ from app.models.boards import Board
from app.models.tasks import Task from app.models.tasks import Task
from app.schemas.activity_events import ActivityEventRead, ActivityTaskCommentFeedItemRead from app.schemas.activity_events import ActivityEventRead, ActivityTaskCommentFeedItemRead
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.organizations import get_active_membership, list_accessible_board_ids from app.services.organizations import (
OrganizationContext,
get_active_membership,
list_accessible_board_ids,
)
router = APIRouter(prefix="/activity", tags=["activity"]) router = APIRouter(prefix="/activity", tags=["activity"])
@@ -134,7 +138,7 @@ async def list_activity(
async def list_task_comment_feed( async def list_task_comment_feed(
board_id: UUID | None = Query(default=None), board_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]: ) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]:
statement = ( statement = (
select(ActivityEvent, Task, Board, Agent) select(ActivityEvent, Task, Board, Agent)
@@ -168,7 +172,7 @@ async def stream_task_comment_feed(
board_id: UUID | None = Query(default=None), board_id: UUID | None = Query(default=None),
since: str | None = Query(default=None), since: str | None = Query(default=None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> EventSourceResponse: ) -> EventSourceResponse:
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)

View File

@@ -10,6 +10,7 @@ from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, or_, update from sqlalchemy import asc, or_, update
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
@@ -245,7 +246,7 @@ async def _require_agent_access(
session: AsyncSession, session: AsyncSession,
*, *,
agent: Agent, agent: Agent,
ctx, ctx: OrganizationContext,
write: bool, write: bool,
) -> None: ) -> None:
if agent.board_id is None: if agent.board_id is None:
@@ -302,7 +303,7 @@ async def list_agents(
board_id: UUID | None = Query(default=None), board_id: UUID | None = Query(default=None),
gateway_id: UUID | None = Query(default=None), gateway_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[AgentRead]: ) -> DefaultLimitOffsetPage[AgentRead]:
main_session_keys = await _get_gateway_main_session_keys(session) main_session_keys = await _get_gateway_main_session_keys(session)
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
@@ -311,7 +312,7 @@ async def list_agents(
if not board_ids: if not board_ids:
statement = select(Agent).where(col(Agent.id).is_(None)) statement = select(Agent).where(col(Agent.id).is_(None))
else: else:
base_filter = col(Agent.board_id).in_(board_ids) base_filter: ColumnElement[bool] = col(Agent.board_id).in_(board_ids)
if is_org_admin(ctx.member): if is_org_admin(ctx.member):
gateway_keys = select(Gateway.main_session_key).where( gateway_keys = select(Gateway.main_session_key).where(
col(Gateway.organization_id) == ctx.organization.id col(Gateway.organization_id) == ctx.organization.id
@@ -342,7 +343,7 @@ async def stream_agents(
board_id: UUID | None = Query(default=None), board_id: UUID | None = Query(default=None),
since: str | None = Query(default=None), since: str | None = Query(default=None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> EventSourceResponse: ) -> EventSourceResponse:
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
last_seen = since_dt last_seen = since_dt
@@ -528,7 +529,7 @@ async def create_agent(
async def get_agent( async def get_agent(
agent_id: str, agent_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> AgentRead: ) -> AgentRead:
agent = await session.get(Agent, agent_id) agent = await session.get(Agent, agent_id)
if agent is None: if agent is None:
@@ -545,7 +546,7 @@ async def update_agent(
force: bool = False, force: bool = False,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> AgentRead: ) -> AgentRead:
agent = await session.get(Agent, agent_id) agent = await session.get(Agent, agent_id)
if agent is None: if agent is None:
@@ -841,12 +842,12 @@ async def heartbeat_or_create_agent(
except Exception as exc: # pragma: no cover - unexpected provisioning errors except Exception as exc: # pragma: no cover - unexpected provisioning errors
_record_instruction_failure(session, agent, str(exc), "provision") _record_instruction_failure(session, agent, str(exc), "provision")
await session.commit() await session.commit()
elif actor.actor_type == "user": else:
if actor.actor_type == "user":
ctx = await _require_user_context(session, actor.user) ctx = await _require_user_context(session, actor.user)
await _require_agent_access(session, agent=agent, ctx=ctx, write=True) await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
elif actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) if agent.agent_token_hash is None:
elif agent.agent_token_hash is None and actor.actor_type == "user":
raw_token = generate_agent_token() raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token) agent.agent_token_hash = hash_agent_token(raw_token)
if agent.heartbeat_config is None: if agent.heartbeat_config is None:
@@ -860,11 +861,13 @@ async def heartbeat_or_create_agent(
board = await _require_board( board = await _require_board(
session, session,
str(agent.board_id) if agent.board_id else None, str(agent.board_id) if agent.board_id else None,
user=actor.user if actor.actor_type == "user" else None, user=actor.user,
write=actor.actor_type == "user", write=True,
) )
gateway, client_config = await _require_gateway(session, board) gateway, client_config = await _require_gateway(session, board)
await provision_agent(agent, board, gateway, raw_token, actor.user, action="provision") await provision_agent(
agent, board, gateway, raw_token, actor.user, action="provision"
)
await _send_wakeup_message(agent, client_config, verb="provisioned") await _send_wakeup_message(agent, client_config, verb="provisioned")
agent.provision_confirm_token_hash = None agent.provision_confirm_token_hash = None
agent.provision_requested_at = None agent.provision_requested_at = None
@@ -891,7 +894,10 @@ async def heartbeat_or_create_agent(
except Exception as exc: # pragma: no cover - unexpected provisioning errors except Exception as exc: # pragma: no cover - unexpected provisioning errors
_record_instruction_failure(session, agent, str(exc), "provision") _record_instruction_failure(session, agent, str(exc), "provision")
await session.commit() await session.commit()
elif not agent.openclaw_session_id: elif actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not agent.openclaw_session_id:
board = await _require_board( board = await _require_board(
session, session,
str(agent.board_id) if agent.board_id else None, str(agent.board_id) if agent.board_id else None,
@@ -934,7 +940,7 @@ async def heartbeat_or_create_agent(
async def delete_agent( async def delete_agent(
agent_id: str, agent_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse: ) -> OkResponse:
agent = await session.get(Agent, agent_id) agent = await session.get(Agent, agent_id)
if agent is None: if agent is None: