refactor: replace DefaultLimitOffsetPage with LimitOffsetPage in multiple files and update timezone handling to use UTC

This commit is contained in:
Abhimanyu Saharan
2026-02-09 20:40:17 +05:30
parent 1f105c19ab
commit 020d02fa22
51 changed files with 302 additions and 192 deletions

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from collections import deque from collections import deque
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from uuid import UUID from uuid import UUID
@@ -36,6 +36,7 @@ from app.services.organizations import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/activity", tags=["activity"]) router = APIRouter(prefix="/activity", tags=["activity"])
@@ -63,7 +64,7 @@ def _parse_since(value: str | None) -> datetime | None:
except ValueError: except ValueError:
return None return None
if parsed.tzinfo is not None: if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None) return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed return parsed
@@ -145,7 +146,7 @@ async def _fetch_task_comment_events(
async def list_activity( async def list_activity(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP, actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[ActivityEventRead]: ) -> LimitOffsetPage[ActivityEventRead]:
"""List activity events visible to the calling actor.""" """List activity events visible to the calling actor."""
statement = select(ActivityEvent) statement = select(ActivityEvent)
if actor.actor_type == "agent" and actor.agent: if actor.actor_type == "agent" and actor.agent:
@@ -174,7 +175,7 @@ async def list_task_comment_feed(
board_id: UUID | None = BOARD_ID_QUERY, board_id: UUID | None = BOARD_ID_QUERY,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]: ) -> LimitOffsetPage[ActivityTaskCommentFeedItemRead]:
"""List task-comment feed items for accessible boards.""" """List task-comment feed items for accessible boards."""
statement = ( statement = (
select(ActivityEvent, Task, Board, Agent) select(ActivityEvent, Task, Board, Agent)

View File

@@ -76,6 +76,7 @@ if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
from uuid import UUID from uuid import UUID
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
@@ -222,7 +223,7 @@ async def _require_gateway_board(
async def list_boards( async def list_boards(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[BoardRead]: ) -> LimitOffsetPage[BoardRead]:
"""List boards visible to the authenticated agent.""" """List boards visible to the authenticated agent."""
statement = select(Board) statement = select(Board)
if agent_ctx.agent.board_id: if agent_ctx.agent.board_id:
@@ -246,7 +247,7 @@ async def list_agents(
board_id: UUID | None = BOARD_ID_QUERY, board_id: UUID | None = BOARD_ID_QUERY,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[AgentRead]: ) -> LimitOffsetPage[AgentRead]:
"""List agents, optionally filtered to a board.""" """List agents, optionally filtered to a board."""
statement = select(Agent) statement = select(Agent)
if agent_ctx.agent.board_id: if agent_ctx.agent.board_id:
@@ -277,7 +278,7 @@ async def list_tasks(
board: Board = BOARD_DEP, board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[TaskRead]: ) -> LimitOffsetPage[TaskRead]:
"""List tasks on a board with optional status and assignment filters.""" """List tasks on a board with optional status and assignment filters."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await tasks_api.list_tasks( return await tasks_api.list_tasks(
@@ -414,7 +415,7 @@ async def list_task_comments(
task: Task = TASK_DEP, task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[TaskCommentRead]: ) -> LimitOffsetPage[TaskCommentRead]:
"""List comments for a task visible to the authenticated agent.""" """List comments for a task visible to the authenticated agent."""
if ( if (
agent_ctx.agent.board_id agent_ctx.agent.board_id
@@ -460,7 +461,7 @@ async def list_board_memory(
board: Board = BOARD_DEP, board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[BoardMemoryRead]: ) -> LimitOffsetPage[BoardMemoryRead]:
"""List board memory entries with optional chat filtering.""" """List board memory entries with optional chat filtering."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await board_memory_api.list_board_memory( return await board_memory_api.list_board_memory(
@@ -497,7 +498,7 @@ async def list_approvals(
board: Board = BOARD_DEP, board: Board = BOARD_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[ApprovalRead]: ) -> LimitOffsetPage[ApprovalRead]:
"""List approvals for a board.""" """List approvals for a board."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await approvals_api.list_approvals( return await approvals_api.list_approvals(
@@ -960,12 +961,12 @@ async def broadcast_gateway_lead_message(
sent = 0 sent = 0
failed = 0 failed = 0
async def _send_to_board(board: Board) -> GatewayLeadBroadcastBoardResult: async def _send_to_board(target_board: Board) -> GatewayLeadBroadcastBoardResult:
try: try:
lead, _lead_created = await ensure_board_lead_agent( lead, _lead_created = await ensure_board_lead_agent(
session, session,
request=LeadAgentRequest( request=LeadAgentRequest(
board=board, board=target_board,
gateway=gateway, gateway=gateway,
config=config, config=config,
user=None, user=None,
@@ -975,14 +976,14 @@ async def broadcast_gateway_lead_message(
lead_session_key = _require_lead_session_key(lead) lead_session_key = _require_lead_session_key(lead)
message = ( message = (
f"{header}\n" f"{header}\n"
f"Board: {board.name}\n" f"Board: {target_board.name}\n"
f"Board ID: {board.id}\n" f"Board ID: {target_board.id}\n"
f"From agent: {agent_ctx.agent.name}\n" f"From agent: {agent_ctx.agent.name}\n"
f"{correlation_line}\n" f"{correlation_line}\n"
f"{payload.content.strip()}\n\n" f"{payload.content.strip()}\n\n"
"Reply to the gateway main by writing a NON-chat memory item " "Reply to the gateway main by writing a NON-chat memory item "
"on this board:\n" "on this board:\n"
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n" f"POST {base_url}/api/v1/agent/boards/{target_board.id}/memory\n"
f'Body: {{"content":"...","tags":{tags_json},' f'Body: {{"content":"...","tags":{tags_json},'
f'"source":"{reply_source}"}}\n' f'"source":"{reply_source}"}}\n'
"Do NOT reply in OpenClaw chat." "Do NOT reply in OpenClaw chat."
@@ -990,14 +991,14 @@ async def broadcast_gateway_lead_message(
await ensure_session(lead_session_key, config=config, label=lead.name) await ensure_session(lead_session_key, config=config, label=lead.name)
await send_message(message, session_key=lead_session_key, config=config) await send_message(message, session_key=lead_session_key, config=config)
return GatewayLeadBroadcastBoardResult( return GatewayLeadBroadcastBoardResult(
board_id=board.id, board_id=target_board.id,
lead_agent_id=lead.id, lead_agent_id=lead.id,
lead_agent_name=lead.name, lead_agent_name=lead.name,
ok=True, ok=True,
) )
except (HTTPException, OpenClawGatewayError, ValueError) as exc: except (HTTPException, OpenClawGatewayError, ValueError) as exc:
return GatewayLeadBroadcastBoardResult( return GatewayLeadBroadcastBoardResult(
board_id=board.id, board_id=target_board.id,
ok=False, ok=False,
error=str(exc), error=str(exc),
) )

View File

@@ -6,7 +6,7 @@ import asyncio
import json import json
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -65,6 +65,7 @@ from app.services.organizations import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
@@ -115,7 +116,7 @@ def _parse_since(value: str | None) -> datetime | None:
except ValueError: except ValueError:
return None return None
if parsed.tzinfo is not None: if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None) return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed return parsed
@@ -564,7 +565,7 @@ async def _validate_agent_update_inputs(
updates: dict[str, Any], updates: dict[str, Any],
make_main: bool | None, make_main: bool | None,
) -> None: ) -> None:
if make_main is True and not is_org_admin(ctx.member): if make_main and not is_org_admin(ctx.member):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if "status" in updates: if "status" in updates:
raise HTTPException( raise HTTPException(
@@ -597,7 +598,7 @@ async def _apply_agent_update_mutations(
) )
gateway_for_main: Gateway | None = None gateway_for_main: Gateway | None = None
if make_main is True: if make_main:
board_source = updates.get("board_id") or agent.board_id board_source = updates.get("board_id") or agent.board_id
board_for_main = await _require_board(session, board_source) board_for_main = await _require_board(session, board_source)
gateway_for_main, _ = await _require_gateway(session, board_for_main) gateway_for_main, _ = await _require_gateway(session, board_for_main)
@@ -605,10 +606,10 @@ async def _apply_agent_update_mutations(
agent.is_board_lead = False agent.is_board_lead = False
agent.openclaw_session_id = gateway_for_main.main_session_key agent.openclaw_session_id = gateway_for_main.main_session_key
main_gateway = gateway_for_main main_gateway = gateway_for_main
elif make_main is False: elif make_main is not None:
agent.openclaw_session_id = None agent.openclaw_session_id = None
if make_main is not True and "board_id" in updates: if not make_main and "board_id" in updates:
await _require_board(session, updates["board_id"]) await _require_board(session, updates["board_id"])
for key, value in updates.items(): for key, value in updates.items():
setattr(agent, key, value) setattr(agent, key, value)
@@ -633,7 +634,7 @@ async def _resolve_agent_update_target(
main_gateway: Gateway | None, main_gateway: Gateway | None,
gateway_for_main: Gateway | None, gateway_for_main: Gateway | None,
) -> _AgentUpdateProvisionTarget: ) -> _AgentUpdateProvisionTarget:
if make_main is True: if make_main:
if gateway_for_main is None: if gateway_for_main is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -955,7 +956,7 @@ async def list_agents(
gateway_id: UUID | None = GATEWAY_ID_QUERY, gateway_id: UUID | None = GATEWAY_ID_QUERY,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[AgentRead]: ) -> LimitOffsetPage[AgentRead]:
"""List agents visible to the active organization admin.""" """List agents visible to the active organization admin."""
main_session_keys = await _get_gateway_main_session_keys(session) main_session_keys = await _get_gateway_main_session_keys(session)
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
@@ -35,6 +35,7 @@ from app.schemas.pagination import DefaultLimitOffsetPage
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board from app.models.boards import Board
@@ -79,7 +80,7 @@ def _parse_since(value: str | None) -> datetime | None:
except ValueError: except ValueError:
return None return None
if parsed.tzinfo is not None: if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None) return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed return parsed
@@ -118,7 +119,7 @@ async def list_approvals(
board: Board = BOARD_READ_DEP, board: Board = BOARD_READ_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
_actor: ActorContext = ACTOR_DEP, _actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[ApprovalRead]: ) -> LimitOffsetPage[ApprovalRead]:
"""List approvals for a board, optionally filtering by status.""" """List approvals for a board, optionally filtering by status."""
statement = Approval.objects.filter_by(board_id=board.id) statement = Approval.objects.filter_by(board_id=board.id)
if status_filter: if status_filter:

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
@@ -53,6 +53,7 @@ from app.services.organizations import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.services.organizations import OrganizationContext from app.services.organizations import OrganizationContext
@@ -90,7 +91,7 @@ def _parse_since(value: str | None) -> datetime | None:
except ValueError: except ValueError:
return None return None
if parsed.tzinfo is not None: if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None) return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed return parsed
@@ -343,7 +344,7 @@ async def list_board_group_memory(
is_chat: bool | None = IS_CHAT_QUERY, is_chat: bool | None = IS_CHAT_QUERY,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: ) -> LimitOffsetPage[BoardGroupMemoryRead]:
"""List board-group memory entries for a specific group.""" """List board-group memory entries for a specific group."""
await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) await _require_group_access(session, group_id=group_id, ctx=ctx, write=False)
statement = ( statement = (
@@ -439,7 +440,7 @@ async def list_board_group_memory_for_board(
is_chat: bool | None = IS_CHAT_QUERY, is_chat: bool | None = IS_CHAT_QUERY,
board: Board = BOARD_READ_DEP, board: Board = BOARD_READ_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: ) -> LimitOffsetPage[BoardGroupMemoryRead]:
"""List memory entries for the board's linked group.""" """List memory entries for the board's linked group."""
group_id = board.board_group_id group_id = board.board_group_id
if group_id is None: if group_id is None:

View File

@@ -50,6 +50,7 @@ from app.services.organizations import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.organization_members import OrganizationMember from app.models.organization_members import OrganizationMember
@@ -103,7 +104,7 @@ async def _require_group_access(
async def list_board_groups( async def list_board_groups(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[BoardGroupRead]: ) -> LimitOffsetPage[BoardGroupRead]:
"""List board groups in the active organization.""" """List board groups in the active organization."""
if member_all_boards_read(ctx.member): if member_all_boards_read(ctx.member):
statement = select(BoardGroup).where( statement = select(BoardGroup).where(

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
@@ -39,6 +39,7 @@ from app.services.mentions import extract_mentions, matches_agent_mention
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board from app.models.boards import Board
@@ -67,7 +68,7 @@ def _parse_since(value: str | None) -> datetime | None:
except ValueError: except ValueError:
return None return None
if parsed.tzinfo is not None: if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None) return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed return parsed
@@ -250,7 +251,7 @@ async def list_board_memory(
board: Board = BOARD_READ_DEP, board: Board = BOARD_READ_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
_actor: ActorContext = ACTOR_DEP, _actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[BoardMemoryRead]: ) -> LimitOffsetPage[BoardMemoryRead]:
"""List board memory entries, optionally filtering chat entries.""" """List board memory entries, optionally filtering chat entries."""
statement = ( statement = (
BoardMemory.objects.filter_by(board_id=board.id) BoardMemory.objects.filter_by(board_id=board.id)

View File

@@ -50,6 +50,7 @@ from app.services.board_snapshot import build_board_snapshot
from app.services.organizations import OrganizationContext, board_access_filter from app.services.organizations import OrganizationContext, board_access_filter
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/boards", tags=["boards"]) router = APIRouter(prefix="/boards", tags=["boards"])
@@ -246,7 +247,7 @@ async def list_boards(
board_group_id: UUID | None = BOARD_GROUP_ID_QUERY, board_group_id: UUID | None = BOARD_GROUP_ID_QUERY,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[BoardRead]: ) -> LimitOffsetPage[BoardRead]:
"""List boards visible to the current organization member.""" """List boards visible to the current organization member."""
statement = select(Board).where(board_access_filter(ctx.member, write=False)) statement = select(Board).where(board_access_filter(ctx.member, write=False))
if gateway_id is not None: if gateway_id is not None:

View File

@@ -46,6 +46,7 @@ from app.services.template_sync import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.services.organizations import OrganizationContext from app.services.organizations import OrganizationContext
@@ -224,7 +225,7 @@ async def _ensure_main_agent(
async def list_gateways( async def list_gateways(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[GatewayRead]: ) -> LimitOffsetPage[GatewayRead]:
"""List gateways for the caller's organization.""" """List gateways for the caller's organization."""
statement = ( statement = (
Gateway.objects.filter_by(organization_id=ctx.organization.id) Gateway.objects.filter_by(organization_id=ctx.organization.id)

View File

@@ -8,7 +8,8 @@ from typing import Literal
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy import DateTime, case, cast, func from sqlalchemy import DateTime, case, func
from sqlalchemy import cast as sql_cast
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -152,7 +153,7 @@ async def _query_cycle_time(
board_ids: list[UUID], board_ids: list[UUID],
) -> DashboardRangeSeries: ) -> DashboardRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket") bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
in_progress = cast(Task.in_progress_at, DateTime) in_progress = sql_cast(Task.in_progress_at, DateTime)
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0 duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
statement = ( statement = (
select(bucket_col, func.avg(duration_hours)) select(bucket_col, func.avg(duration_hours))
@@ -249,7 +250,7 @@ async def _median_cycle_time_7d(
) -> float | None: ) -> float | None:
now = utcnow() now = utcnow()
start = now - timedelta(days=7) start = now - timedelta(days=7)
in_progress = cast(Task.in_progress_at, DateTime) in_progress = sql_cast(Task.in_progress_at, DateTime)
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0 duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
statement = ( statement = (
select(func.percentile_cont(0.5).within_group(duration_hours)) select(func.percentile_cont(0.5).within_group(duration_hours))

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import secrets import secrets
from typing import TYPE_CHECKING, Any, Sequence from typing import TYPE_CHECKING, Any
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
@@ -65,6 +65,9 @@ from app.services.organizations import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext from app.core.auth import AuthContext
@@ -369,7 +372,7 @@ async def get_my_membership(
async def list_org_members( async def list_org_members(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[OrganizationMemberRead]: ) -> LimitOffsetPage[OrganizationMemberRead]:
"""List members for the active organization.""" """List members for the active organization."""
statement = ( statement = (
select(OrganizationMember, User) select(OrganizationMember, User)
@@ -542,7 +545,7 @@ async def remove_org_member(
async def list_org_invites( async def list_org_invites(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[OrganizationInviteRead]: ) -> LimitOffsetPage[OrganizationInviteRead]:
"""List pending invites for the active organization.""" """List pending invites for the active organization."""
statement = ( statement = (
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id) OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)

View File

@@ -3,13 +3,15 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar from typing import TYPE_CHECKING, Any, Generic, TypeVar
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.db.queryset import QuerySet, qs from app.db.queryset import QuerySet, qs
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlalchemy.orm import Mapped
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
@@ -27,11 +29,17 @@ class APIQuerySet(Generic[ModelT]):
"""Expose the underlying SQL statement for advanced composition.""" """Expose the underlying SQL statement for advanced composition."""
return self.queryset.statement return self.queryset.statement
def filter(self, *criteria: object) -> APIQuerySet[ModelT]: def filter(
self,
*criteria: ColumnElement[bool] | bool,
) -> APIQuerySet[ModelT]:
"""Return a new queryset with additional SQL criteria applied.""" """Return a new queryset with additional SQL criteria applied."""
return APIQuerySet(self.queryset.filter(*criteria)) return APIQuerySet(self.queryset.filter(*criteria))
def order_by(self, *ordering: object) -> APIQuerySet[ModelT]: def order_by(
self,
*ordering: Mapped[Any] | ColumnElement[Any] | str,
) -> APIQuerySet[ModelT]:
"""Return a new queryset with ordering clauses applied.""" """Return a new queryset with ordering clauses applied."""
return APIQuerySet(self.queryset.order_by(*ordering)) return APIQuerySet(self.queryset.order_by(*ordering))

View File

@@ -7,8 +7,8 @@ import json
from collections import deque from collections import deque
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -67,8 +67,9 @@ from app.services.task_dependencies import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
from app.core.auth import AuthContext from app.core.auth import AuthContext
from app.models.users import User from app.models.users import User
@@ -85,6 +86,7 @@ TASK_EVENT_TYPES = {
SSE_SEEN_MAX = 2000 SSE_SEEN_MAX = 2000
TASK_SNIPPET_MAX_LEN = 500 TASK_SNIPPET_MAX_LEN = 500
TASK_SNIPPET_TRUNCATED_LEN = 497 TASK_SNIPPET_TRUNCATED_LEN = 497
TASK_EVENT_ROW_LEN = 2
BOARD_READ_DEP = Depends(get_board_for_actor_read) BOARD_READ_DEP = Depends(get_board_for_actor_read)
ACTOR_DEP = Depends(require_admin_or_agent) ACTOR_DEP = Depends(require_admin_or_agent)
SINCE_QUERY = Query(default=None) SINCE_QUERY = Query(default=None)
@@ -154,7 +156,7 @@ def _parse_since(value: str | None) -> datetime | None:
except ValueError: except ValueError:
return None return None
if parsed.tzinfo is not None: if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None) return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed return parsed
@@ -168,6 +170,24 @@ def _coerce_task_items(items: Sequence[object]) -> list[Task]:
return tasks return tasks
def _coerce_task_event_rows(
items: Sequence[object],
) -> list[tuple[ActivityEvent, Task | None]]:
rows: list[tuple[ActivityEvent, Task | None]] = []
for item in items:
if (
isinstance(item, tuple)
and len(item) == TASK_EVENT_ROW_LEN
and isinstance(item[0], ActivityEvent)
and (isinstance(item[1], Task) or item[1] is None)
):
rows.append((item[0], item[1]))
continue
msg = "Expected (ActivityEvent, Task | None) rows"
raise TypeError(msg)
return rows
async def _lead_was_mentioned( async def _lead_was_mentioned(
session: AsyncSession, session: AsyncSession,
task: Task, task: Task,
@@ -276,16 +296,16 @@ async def _fetch_task_events(
) )
if not task_ids: if not task_ids:
return [] return []
statement = cast( statement = (
"Select[tuple[ActivityEvent, Task | None]]",
select(ActivityEvent, Task) select(ActivityEvent, Task)
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id)) .outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
.where(col(ActivityEvent.task_id).in_(task_ids)) .where(col(ActivityEvent.task_id).in_(task_ids))
.where(col(ActivityEvent.event_type).in_(TASK_EVENT_TYPES)) .where(col(ActivityEvent.event_type).in_(TASK_EVENT_TYPES))
.where(col(ActivityEvent.created_at) >= since) .where(col(ActivityEvent.created_at) >= since)
.order_by(asc(col(ActivityEvent.created_at))), .order_by(asc(col(ActivityEvent.created_at)))
) )
return list(await session.exec(statement)) result = await session.execute(statement)
return _coerce_task_event_rows(list(result.tuples().all()))
def _serialize_comment(event: ActivityEvent) -> dict[str, object]: def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
@@ -718,7 +738,7 @@ async def list_tasks(
board: Board = BOARD_READ_DEP, board: Board = BOARD_READ_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
_actor: ActorContext = ACTOR_DEP, _actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[TaskRead]: ) -> LimitOffsetPage[TaskRead]:
"""List board tasks with optional status and assignment filters.""" """List board tasks with optional status and assignment filters."""
statement = _task_list_statement( statement = _task_list_statement(
board_id=board.id, board_id=board.id,
@@ -914,7 +934,7 @@ async def delete_task(
async def list_task_comments( async def list_task_comments(
task: Task = TASK_DEP, task: Task = TASK_DEP,
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
) -> DefaultLimitOffsetPage[TaskCommentRead]: ) -> LimitOffsetPage[TaskCommentRead]:
"""List comments for a task in chronological order.""" """List comments for a task in chronological order."""
statement = ( statement = (
select(ActivityEvent) select(ActivityEvent)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import logging import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, Final, cast from typing import TYPE_CHECKING, Any, Final
from uuid import uuid4 from uuid import uuid4
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
@@ -81,19 +81,49 @@ def install_error_handling(app: FastAPI) -> None:
app.add_exception_handler( app.add_exception_handler(
RequestValidationError, RequestValidationError,
cast(ExceptionHandler, _request_validation_handler), _request_validation_exception_handler,
) )
app.add_exception_handler( app.add_exception_handler(
ResponseValidationError, ResponseValidationError,
cast(ExceptionHandler, _response_validation_handler), _response_validation_exception_handler,
) )
app.add_exception_handler( app.add_exception_handler(
StarletteHTTPException, StarletteHTTPException,
cast(ExceptionHandler, _http_exception_handler), _http_exception_exception_handler,
) )
app.add_exception_handler(Exception, _unhandled_exception_handler) app.add_exception_handler(Exception, _unhandled_exception_handler)
async def _request_validation_exception_handler(
request: Request,
exc: Exception,
) -> Response:
if not isinstance(exc, RequestValidationError):
msg = "Expected RequestValidationError"
raise TypeError(msg)
return await _request_validation_handler(request, exc)
async def _response_validation_exception_handler(
request: Request,
exc: Exception,
) -> Response:
if not isinstance(exc, ResponseValidationError):
msg = "Expected ResponseValidationError"
raise TypeError(msg)
return await _response_validation_handler(request, exc)
async def _http_exception_exception_handler(
request: Request,
exc: Exception,
) -> Response:
if not isinstance(exc, StarletteHTTPException):
msg = "Expected StarletteHTTPException"
raise TypeError(msg)
return await _http_exception_handler(request, exc)
def _get_request_id(request: Request) -> str | None: def _get_request_id(request: Request) -> str | None:
request_id = getattr(request.state, "request_id", None) request_id = getattr(request.state, "request_id", None)
if isinstance(request_id, str) and request_id: if isinstance(request_id, str) and request_id:

View File

@@ -7,7 +7,7 @@ import logging
import os import os
import sys import sys
import time import time
from datetime import datetime, timezone from datetime import UTC, datetime
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
@@ -128,7 +128,7 @@ class JsonFormatter(logging.Formatter):
payload: dict[str, Any] = { payload: dict[str, Any] = {
"timestamp": datetime.fromtimestamp( "timestamp": datetime.fromtimestamp(
record.created, record.created,
tz=timezone.utc, tz=UTC,
).isoformat(), ).isoformat(),
"level": record.levelname, "level": record.levelname,
"logger": record.name, "logger": record.name,
@@ -153,6 +153,7 @@ class JsonFormatter(logging.Formatter):
class KeyValueFormatter(logging.Formatter): class KeyValueFormatter(logging.Formatter):
"""Formatter that appends extra fields as `key=value` pairs.""" """Formatter that appends extra fields as `key=value` pairs."""
# noinspection PyMethodMayBeStatic
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
"""Render a log line with appended non-standard record fields.""" """Render a log line with appended non-standard record fields."""
base = super().format(record) base = super().format(record)

View File

@@ -132,7 +132,7 @@ async def save(
return obj return obj
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None: async def delete(session: AsyncSession, obj: SQLModel, *, commit: bool = True) -> None:
"""Delete an object with optional commit.""" """Delete an object with optional commit."""
await session.delete(obj) await session.delete(obj)
if commit: if commit:

View File

@@ -3,19 +3,23 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar from sqlmodel.sql.expression import Select, SelectOfScalar
T = TypeVar("T") T = TypeVar("T")
Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]] Transformer = Callable[
[Sequence[Any]],
Sequence[Any] | Awaitable[Sequence[Any]],
]
async def paginate( async def paginate(
@@ -23,12 +27,7 @@ async def paginate(
statement: Select[Any] | SelectOfScalar[Any], statement: Select[Any] | SelectOfScalar[Any],
*, *,
transformer: Transformer | None = None, transformer: Transformer | None = None,
) -> DefaultLimitOffsetPage[T]: ) -> LimitOffsetPage[T]:
"""Execute a paginated query and cast to the project page type alias.""" """Execute a paginated query and cast to the project page type alias."""
# fastapi-pagination is not fully typed (it returns Any), but response_model page = await _paginate(session, statement, transformer=transformer)
# validation ensures runtime correctness. Centralize casts here to keep strict return DefaultLimitOffsetPage[T].model_validate(page)
# mypy clean.
return cast(
DefaultLimitOffsetPage[T],
await _paginate(session, statement, transformer=transformer),
)

View File

@@ -13,6 +13,8 @@ from app.db.queryset import QuerySet, qs
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterable from collections.abc import Iterable
from sqlalchemy.sql.elements import ColumnElement
ModelT = TypeVar("ModelT", bound=SQLModel) ModelT = TypeVar("ModelT", bound=SQLModel)
@@ -31,11 +33,17 @@ class ModelManager(Generic[ModelT]):
"""Return a queryset that yields no rows.""" """Return a queryset that yields no rows."""
return qs(self.model).filter(false()) return qs(self.model).filter(false())
def filter(self, *criteria: object) -> QuerySet[ModelT]: def filter(
self,
*criteria: ColumnElement[bool] | bool,
) -> QuerySet[ModelT]:
"""Return queryset filtered by SQL criteria expressions.""" """Return queryset filtered by SQL criteria expressions."""
return self.all().filter(*criteria) return self.all().filter(*criteria)
def where(self, *criteria: object) -> QuerySet[ModelT]: def where(
self,
*criteria: ColumnElement[bool] | bool,
) -> QuerySet[ModelT]:
"""Alias for `filter`.""" """Alias for `filter`."""
return self.filter(*criteria) return self.filter(*criteria)
@@ -76,6 +84,7 @@ class ModelManager(Generic[ModelT]):
class ManagerDescriptor(Generic[ModelT]): class ManagerDescriptor(Generic[ModelT]):
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`.""" """Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
# noinspection PyMethodMayBeStatic
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]: def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
"""Return a fresh manager bound to the owning model class.""" """Return a fresh manager bound to the owning model class."""
return ModelManager(owner) return ModelManager(owner)

View File

@@ -3,11 +3,13 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from typing import TYPE_CHECKING, Any, Generic, TypeVar
from sqlmodel import select from sqlmodel import select
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlalchemy.orm import Mapped
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
@@ -20,15 +22,18 @@ class QuerySet(Generic[ModelT]):
statement: SelectOfScalar[ModelT] statement: SelectOfScalar[ModelT]
def filter(self, *criteria: object) -> QuerySet[ModelT]: def filter(
self,
*criteria: ColumnElement[bool] | bool,
) -> QuerySet[ModelT]:
"""Return a new queryset with additional SQL criteria.""" """Return a new queryset with additional SQL criteria."""
statement = cast( statement = self.statement.where(*criteria)
"SelectOfScalar[ModelT]",
cast(Any, self.statement).where(*criteria),
)
return replace(self, statement=statement) return replace(self, statement=statement)
def where(self, *criteria: object) -> QuerySet[ModelT]: def where(
self,
*criteria: ColumnElement[bool] | bool,
) -> QuerySet[ModelT]:
"""Alias for `filter` to mirror SQLAlchemy naming.""" """Alias for `filter` to mirror SQLAlchemy naming."""
return self.filter(*criteria) return self.filter(*criteria)
@@ -37,12 +42,12 @@ class QuerySet(Generic[ModelT]):
statement = self.statement.filter_by(**kwargs) statement = self.statement.filter_by(**kwargs)
return replace(self, statement=statement) return replace(self, statement=statement)
def order_by(self, *ordering: object) -> QuerySet[ModelT]: def order_by(
self,
*ordering: Mapped[Any] | ColumnElement[Any] | str,
) -> QuerySet[ModelT]:
"""Return a new queryset with ordering clauses applied.""" """Return a new queryset with ordering clauses applied."""
statement = cast( statement = self.statement.order_by(*ordering)
"SelectOfScalar[ModelT]",
cast(Any, self.statement).order_by(*ordering),
)
return replace(self, statement=statement) return replace(self, statement=statement)
def limit(self, value: int) -> QuerySet[ModelT]: def limit(self, value: int) -> QuerySet[ModelT]:

View File

@@ -2,11 +2,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import anyio
from alembic import command from alembic import command
from alembic.config import Config from alembic.config import Config
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@@ -65,11 +65,11 @@ async def init_db() -> None:
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions" versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
if any(versions_dir.glob("*.py")): if any(versions_dir.glob("*.py")):
logger.info("Running migrations on startup") logger.info("Running migrations on startup")
await anyio.to_thread.run_sync(run_migrations) await asyncio.to_thread(run_migrations)
return return
logger.warning("No migration revisions found; falling back to create_all") logger.warning("No migration revisions found; falling back to create_all")
async with async_engine.begin() as conn: async with async_engine.connect() as conn, conn.begin():
await conn.run_sync(SQLModel.metadata.create_all) await conn.run_sync(SQLModel.metadata.create_all)

View File

@@ -56,7 +56,8 @@ async def _await_response(
data = json.loads(raw) data = json.loads(raw)
if data.get("type") == "res" and data.get("id") == request_id: if data.get("type") == "res" and data.get("id") == request_id:
if data.get("ok") is False: ok = data.get("ok")
if ok is not None and not ok:
error = data.get("error", {}).get("message", "Gateway error") error = data.get("error", {}).get("message", "Gateway error")
raise OpenClawGatewayError(error) raise OpenClawGatewayError(error)
return data.get("payload") return data.get("payload")
@@ -135,14 +136,14 @@ async def openclaw_call(
first_message = None first_message = None
try: try:
first_message = await asyncio.wait_for(ws.recv(), timeout=2) first_message = await asyncio.wait_for(ws.recv(), timeout=2)
except asyncio.TimeoutError: except TimeoutError:
first_message = None first_message = None
await _ensure_connected(ws, first_message, config) await _ensure_connected(ws, first_message, config)
return await _send_request(ws, method, params) return await _send_request(ws, method, params)
except OpenClawGatewayError: except OpenClawGatewayError:
raise raise
except ( except (
asyncio.TimeoutError, TimeoutError,
ConnectionError, ConnectionError,
OSError, OSError,
ValueError, ValueError,

View File

@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class ActivityEvent(QueryModel, table=True): class ActivityEvent(QueryModel, table=True):
"""Discrete activity event tied to tasks and agents.""" """Discrete activity event tied to tasks and agents."""
__tablename__ = "activity_events" __tablename__ = "activity_events" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
event_type: str = Field(index=True) event_type: str = Field(index=True)

View File

@@ -18,7 +18,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class Agent(QueryModel, table=True): class Agent(QueryModel, table=True):
"""Agent configuration and lifecycle state persisted in the database.""" """Agent configuration and lifecycle state persisted in the database."""
__tablename__ = "agents" __tablename__ = "agents" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True) board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class Approval(QueryModel, table=True): class Approval(QueryModel, table=True):
"""Approval request and decision metadata for gated operations.""" """Approval request and decision metadata for gated operations."""
__tablename__ = "approvals" __tablename__ = "approvals" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID = Field(foreign_key="boards.id", index=True) board_id: UUID = Field(foreign_key="boards.id", index=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class BoardGroupMemory(QueryModel, table=True): class BoardGroupMemory(QueryModel, table=True):
"""Persisted memory items associated with a board group.""" """Persisted memory items associated with a board group."""
__tablename__ = "board_group_memory" __tablename__ = "board_group_memory" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_group_id: UUID = Field(foreign_key="board_groups.id", index=True) board_group_id: UUID = Field(foreign_key="board_groups.id", index=True)

View File

@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class BoardGroup(TenantScoped, table=True): class BoardGroup(TenantScoped, table=True):
"""Logical grouping container for boards within an organization.""" """Logical grouping container for boards within an organization."""
__tablename__ = "board_groups" __tablename__ = "board_groups" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_id: UUID = Field(foreign_key="organizations.id", index=True) organization_id: UUID = Field(foreign_key="organizations.id", index=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class BoardMemory(QueryModel, table=True): class BoardMemory(QueryModel, table=True):
"""Persisted memory item attached directly to a board.""" """Persisted memory item attached directly to a board."""
__tablename__ = "board_memory" __tablename__ = "board_memory" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID = Field(foreign_key="boards.id", index=True) board_id: UUID = Field(foreign_key="boards.id", index=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class BoardOnboardingSession(QueryModel, table=True): class BoardOnboardingSession(QueryModel, table=True):
"""Persisted onboarding conversation and draft goal data for a board.""" """Persisted onboarding conversation and draft goal data for a board."""
__tablename__ = "board_onboarding_sessions" __tablename__ = "board_onboarding_sessions" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID = Field(foreign_key="boards.id", index=True) board_id: UUID = Field(foreign_key="boards.id", index=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class Board(TenantScoped, table=True): class Board(TenantScoped, table=True):
"""Primary board entity grouping tasks, agents, and goal metadata.""" """Primary board entity grouping tasks, agents, and goal metadata."""
__tablename__ = "boards" __tablename__ = "boards" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_id: UUID = Field(foreign_key="organizations.id", index=True) organization_id: UUID = Field(foreign_key="organizations.id", index=True)

View File

@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class Gateway(QueryModel, table=True): class Gateway(QueryModel, table=True):
"""Configured external gateway endpoint and authentication settings.""" """Configured external gateway endpoint and authentication settings."""
__tablename__ = "gateways" __tablename__ = "gateways" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_id: UUID = Field(foreign_key="organizations.id", index=True) organization_id: UUID = Field(foreign_key="organizations.id", index=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class OrganizationBoardAccess(QueryModel, table=True): class OrganizationBoardAccess(QueryModel, table=True):
"""Member-specific board permissions within an organization.""" """Member-specific board permissions within an organization."""
__tablename__ = "organization_board_access" __tablename__ = "organization_board_access" # pyright: ignore[reportAssignmentType]
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"organization_member_id", "organization_member_id",

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class OrganizationInviteBoardAccess(QueryModel, table=True): class OrganizationInviteBoardAccess(QueryModel, table=True):
"""Invite-specific board permissions applied after invite acceptance.""" """Invite-specific board permissions applied after invite acceptance."""
__tablename__ = "organization_invite_board_access" __tablename__ = "organization_invite_board_access" # pyright: ignore[reportAssignmentType]
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"organization_invite_id", "organization_invite_id",

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class OrganizationInvite(QueryModel, table=True): class OrganizationInvite(QueryModel, table=True):
"""Invitation record granting prospective organization access.""" """Invitation record granting prospective organization access."""
__tablename__ = "organization_invites" __tablename__ = "organization_invites" # pyright: ignore[reportAssignmentType]
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),) __table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class OrganizationMember(QueryModel, table=True): class OrganizationMember(QueryModel, table=True):
"""Membership row linking a user to an organization and permissions.""" """Membership row linking a user to an organization and permissions."""
__tablename__ = "organization_members" __tablename__ = "organization_members" # pyright: ignore[reportAssignmentType]
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"organization_id", "organization_id",

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class Organization(QueryModel, table=True): class Organization(QueryModel, table=True):
"""Top-level organization tenant record.""" """Top-level organization tenant record."""
__tablename__ = "organizations" __tablename__ = "organizations" # pyright: ignore[reportAssignmentType]
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),) __table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class TaskDependency(TenantScoped, table=True): class TaskDependency(TenantScoped, table=True):
"""Directed dependency edge between two tasks in the same board.""" """Directed dependency edge between two tasks in the same board."""
__tablename__ = "task_dependencies" __tablename__ = "task_dependencies" # pyright: ignore[reportAssignmentType]
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"task_id", "task_id",

View File

@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class TaskFingerprint(QueryModel, table=True): class TaskFingerprint(QueryModel, table=True):
"""Hashed task-content fingerprint associated with a board and task.""" """Hashed task-content fingerprint associated with a board and task."""
__tablename__ = "task_fingerprints" __tablename__ = "task_fingerprints" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID = Field(foreign_key="boards.id", index=True) board_id: UUID = Field(foreign_key="boards.id", index=True)

View File

@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
class Task(TenantScoped, table=True): class Task(TenantScoped, table=True):
"""Board-scoped task entity with ownership, status, and timing fields.""" """Board-scoped task entity with ownership, status, and timing fields."""
__tablename__ = "tasks" __tablename__ = "tasks" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True) board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True)

View File

@@ -12,7 +12,7 @@ from app.models.base import QueryModel
class User(QueryModel, table=True): class User(QueryModel, table=True):
"""Application user account and profile attributes.""" """Application user account and profile attributes."""
__tablename__ = "users" __tablename__ = "users" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
clerk_user_id: str = Field(index=True, unique=True) clerk_user_id: str = Field(index=True, unique=True)

View File

@@ -34,11 +34,13 @@ class BoardBase(SQLModel):
class BoardCreate(BoardBase): class BoardCreate(BoardBase):
"""Payload for creating a board.""" """Payload for creating a board."""
gateway_id: UUID gateway_id: UUID | None = None
@model_validator(mode="after") @model_validator(mode="after")
def validate_goal_fields(self) -> Self: def validate_goal_fields(self) -> Self:
"""Require goal details when creating a confirmed goal board.""" """Require gateway and goal details when creating a confirmed goal board."""
if self.gateway_id is None:
raise ValueError(_ERR_GATEWAY_REQUIRED)
if ( if (
self.board_type == "goal" self.board_type == "goal"
and self.goal_confirmed and self.goal_confirmed

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TypeVar from typing import TYPE_CHECKING, TypeVar
from fastapi import Query from fastapi import Query
from fastapi_pagination.customization import CustomizedPage, UseParamsFields from fastapi_pagination.customization import CustomizedPage, UseParamsFields
@@ -14,10 +14,15 @@ T = TypeVar("T")
# Project-wide default pagination response model. # Project-wide default pagination response model.
# - Keep `limit` / `offset` naming (matches existing API conventions). # - Keep `limit` / `offset` naming (matches existing API conventions).
# - Cap list endpoints to 200 items per request (matches prior route-level constraints). # - Cap list endpoints to 200 items per request (matches prior route-level constraints).
DefaultLimitOffsetPage = CustomizedPage[ if TYPE_CHECKING:
LimitOffsetPage[T], # Type checkers treat this as a normal generic page type.
UseParamsFields( DefaultLimitOffsetPage = LimitOffsetPage
limit=Query(200, ge=1, le=200), else:
offset=Query(0, ge=0), # Runtime uses project-default query param bounds for all list endpoints.
), DefaultLimitOffsetPage = CustomizedPage[
] LimitOffsetPage[T],
UseParamsFields(
limit=Query(200, ge=1, le=200),
offset=Query(0, ge=0),
),
]

View File

@@ -738,7 +738,7 @@ def _should_include_bootstrap(
if not existing_files: if not existing_files:
return False return False
entry = existing_files.get("BOOTSTRAP.md") entry = existing_files.get("BOOTSTRAP.md")
return not (entry and entry.get("missing") is True) return not bool(entry and entry.get("missing"))
async def _set_agent_files( async def _set_agent_files(
@@ -753,7 +753,7 @@ async def _set_agent_files(
continue continue
if name in PRESERVE_AGENT_EDITABLE_FILES: if name in PRESERVE_AGENT_EDITABLE_FILES:
entry = existing_files.get(name) entry = existing_files.get(name)
if entry and entry.get("missing") is not True: if entry and not bool(entry.get("missing")):
continue continue
try: try:
await openclaw_call( await openclaw_call(

View File

@@ -117,20 +117,20 @@ def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
visited: set[UUID] = set() visited: set[UUID] = set()
in_stack: set[UUID] = set() in_stack: set[UUID] = set()
def dfs(node: UUID) -> bool: def dfs(current: UUID) -> bool:
if node in in_stack: if current in in_stack:
return True return True
if node in visited: if current in visited:
return False return False
visited.add(node) visited.add(current)
in_stack.add(node) in_stack.add(current)
for nxt in edges.get(node, set()): for nxt in edges.get(current, set()):
if dfs(nxt): if dfs(nxt):
return True return True
in_stack.remove(node) in_stack.remove(current)
return False return False
return any(dfs(node) for node in nodes) return any(dfs(start_node) for start_node in nodes)
async def validate_dependency_update( async def validate_dependency_update(

View File

@@ -132,8 +132,8 @@ class _GatewayBackoff:
def reset(self) -> None: def reset(self) -> None:
self._delay_s = self._base_delay_s self._delay_s = self._base_delay_s
@staticmethod
async def _attempt( async def _attempt(
self,
fn: Callable[[], Awaitable[T]], fn: Callable[[], Awaitable[T]],
) -> tuple[T | None, OpenClawGatewayError | None]: ) -> tuple[T | None, OpenClawGatewayError | None]:
try: try:

View File

@@ -4,16 +4,15 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import SimpleNamespace from typing import Any
from typing import TYPE_CHECKING, cast
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from app.api import board_groups from app.api import board_groups
from app.models.organization_members import OrganizationMember
if TYPE_CHECKING: from app.models.organizations import Organization
from sqlmodel.ext.asyncio.session import AsyncSession from app.services.organizations import OrganizationContext
@dataclass @dataclass
@@ -47,12 +46,20 @@ async def test_delete_board_group_cleans_group_memory_first(
_fake_require_group_access, _fake_require_group_access,
) )
session = _FakeSession() session: Any = _FakeSession()
ctx = SimpleNamespace(member=object()) org_id = uuid4()
ctx = OrganizationContext(
organization=Organization(id=org_id, name=f"org-{org_id}"),
member=OrganizationMember(
organization_id=org_id,
user_id=uuid4(),
role="admin",
),
)
await board_groups.delete_board_group( await board_groups.delete_board_group(
group_id=group_id, group_id=group_id,
session=cast("AsyncSession", session), session=session,
ctx=ctx, ctx=ctx,
) )

View File

@@ -4,7 +4,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, cast from typing import Any
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
@@ -12,9 +12,6 @@ import pytest
from app.api import boards from app.api import boards
from app.models.boards import Board from app.models.boards import Board
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
_NO_EXEC_RESULTS_ERROR = "No more exec_results left for session.exec" _NO_EXEC_RESULTS_ERROR = "No more exec_results left for session.exec"
@@ -47,7 +44,7 @@ class _FakeSession:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_board_cleans_org_board_access_rows() -> None: async def test_delete_board_cleans_org_board_access_rows() -> None:
"""Deleting a board should clear org-board access rows before commit.""" """Deleting a board should clear org-board access rows before commit."""
session = _FakeSession(exec_results=[[], []]) session: Any = _FakeSession(exec_results=[[], []])
board = Board( board = Board(
id=uuid4(), id=uuid4(),
organization_id=uuid4(), organization_id=uuid4(),
@@ -57,7 +54,7 @@ async def test_delete_board_cleans_org_board_access_rows() -> None:
) )
await boards.delete_board( await boards.delete_board(
session=cast("AsyncSession", session), session=session,
board=board, board=board,
) )

View File

@@ -50,7 +50,8 @@ async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.Mo
class _FakeDependencySession: class _FakeDependencySession:
rollbacks: int = 0 rollbacks: int = 0
def in_transaction(self) -> bool: @staticmethod
def in_transaction() -> bool:
return True return True
async def rollback(self) -> None: async def rollback(self) -> None:
@@ -89,16 +90,19 @@ async def test_create_rolls_back_when_commit_fails() -> None:
def add(self, value: Any) -> None: def add(self, value: Any) -> None:
self.added.append(value) self.added.append(value)
async def flush(self) -> None: @staticmethod
async def flush() -> None:
return None return None
async def commit(self) -> None: @staticmethod
async def commit() -> None:
raise _CommitError("commit failed") raise _CommitError("commit failed")
async def rollback(self) -> None: async def rollback(self) -> None:
self.rollback_calls += 1 self.rollback_calls += 1
async def refresh(self, _value: Any) -> None: @staticmethod
async def refresh(_value: Any) -> None:
return None return None
session = _FailCommitSession() session = _FailCommitSession()
@@ -124,7 +128,8 @@ async def test_delete_where_rolls_back_when_commit_fails() -> None:
self.exec_calls += 1 self.exec_calls += 1
return SimpleNamespace(rowcount=3) return SimpleNamespace(rowcount=3)
async def commit(self) -> None: @staticmethod
async def commit() -> None:
raise _CommitError("commit failed") raise _CommitError("commit failed")
async def rollback(self) -> None: async def rollback(self) -> None:

View File

@@ -4,17 +4,16 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import SimpleNamespace from typing import Any
from typing import TYPE_CHECKING, cast
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.api import organizations from app.api import organizations
from app.models.organization_members import OrganizationMember
if TYPE_CHECKING: from app.models.organizations import Organization
from sqlmodel.ext.asyncio.session import AsyncSession from app.services.organizations import OrganizationContext
@dataclass @dataclass
@@ -35,15 +34,19 @@ class _FakeSession:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_my_org_cleans_dependents_before_organization_delete() -> None: async def test_delete_my_org_cleans_dependents_before_organization_delete() -> None:
"""Delete flow should remove dependent rows before the organization row.""" """Delete flow should remove dependent rows before the organization row."""
session = _FakeSession() session: Any = _FakeSession()
org_id = uuid4() org_id = uuid4()
ctx = SimpleNamespace( ctx = OrganizationContext(
organization=SimpleNamespace(id=org_id), organization=Organization(id=org_id, name=f"org-{org_id}"),
member=SimpleNamespace(role="owner"), member=OrganizationMember(
organization_id=org_id,
user_id=uuid4(),
role="owner",
),
) )
await organizations.delete_my_org( await organizations.delete_my_org(
session=cast("AsyncSession", session), session=session,
ctx=ctx, ctx=ctx,
) )
@@ -77,15 +80,20 @@ async def test_delete_my_org_cleans_dependents_before_organization_delete() -> N
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_my_org_requires_owner_role() -> None: async def test_delete_my_org_requires_owner_role() -> None:
"""Delete flow should reject non-owner members with HTTP 403.""" """Delete flow should reject non-owner members with HTTP 403."""
session = _FakeSession() session: Any = _FakeSession()
ctx = SimpleNamespace( org_id = uuid4()
organization=SimpleNamespace(id=uuid4()), ctx = OrganizationContext(
member=SimpleNamespace(role="admin"), organization=Organization(id=org_id, name=f"org-{org_id}"),
member=OrganizationMember(
organization_id=org_id,
user_id=uuid4(),
role="admin",
),
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await organizations.delete_my_org( await organizations.delete_my_org(
session=cast("AsyncSession", session), session=session,
ctx=ctx, ctx=ctx,
) )

View File

@@ -3,16 +3,17 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any from typing import Any
from uuid import uuid4 from uuid import UUID, uuid4
import pytest import pytest
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.api import organizations from app.api import organizations
from app.models.organization_members import OrganizationMember from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization
from app.models.users import User from app.models.users import User
from app.services.organizations import OrganizationContext
@dataclass @dataclass
@@ -58,6 +59,17 @@ class _FakeSession:
self.committed += 1 self.committed += 1
def _make_ctx(*, org_id: UUID, user_id: UUID, role: str) -> OrganizationContext:
return OrganizationContext(
organization=Organization(id=org_id, name=f"org-{org_id}"),
member=OrganizationMember(
organization_id=org_id,
user_id=user_id,
role=role,
),
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_org_member_deletes_member_access_and_member() -> None: async def test_remove_org_member_deletes_member_access_and_member() -> None:
org_id = uuid4() org_id = uuid4()
@@ -83,10 +95,7 @@ async def test_remove_org_member_deletes_member_access_and_member() -> None:
_FakeExecResult(first_value=fallback_org_id), _FakeExecResult(first_value=fallback_org_id),
], ],
) )
ctx = SimpleNamespace( ctx = _make_ctx(org_id=org_id, user_id=actor_user_id, role="admin")
organization=SimpleNamespace(id=org_id),
member=SimpleNamespace(user_id=actor_user_id, role="admin"),
)
await organizations.remove_org_member(member_id=member_id, session=session, ctx=ctx) await organizations.remove_org_member(member_id=member_id, session=session, ctx=ctx)
@@ -109,10 +118,7 @@ async def test_remove_org_member_disallows_self_removal() -> None:
role="member", role="member",
) )
session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)]) session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
ctx = SimpleNamespace( ctx = _make_ctx(org_id=org_id, user_id=user_id, role="owner")
organization=SimpleNamespace(id=org_id),
member=SimpleNamespace(user_id=user_id, role="owner"),
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx) await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
@@ -133,10 +139,7 @@ async def test_remove_org_member_requires_owner_to_remove_owner() -> None:
role="owner", role="owner",
) )
session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)]) session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
ctx = SimpleNamespace( ctx = _make_ctx(org_id=org_id, user_id=uuid4(), role="admin")
organization=SimpleNamespace(id=org_id),
member=SimpleNamespace(user_id=uuid4(), role="admin"),
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx) await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
@@ -162,10 +165,7 @@ async def test_remove_org_member_rejects_removing_last_owner() -> None:
_FakeExecResult(all_values=[member]), _FakeExecResult(all_values=[member]),
], ],
) )
ctx = SimpleNamespace( ctx = _make_ctx(org_id=org_id, user_id=uuid4(), role="owner")
organization=SimpleNamespace(id=org_id),
member=SimpleNamespace(user_id=uuid4(), role="owner"),
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx) await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)

View File

@@ -17,8 +17,8 @@ async def test_request_id_middleware_passes_through_non_http_scope() -> None:
middleware = RequestIdMiddleware(app) middleware = RequestIdMiddleware(app)
scope = {"type": "websocket", "headers": []} request_scope = {"type": "websocket", "headers": []}
await middleware(scope, lambda: None, lambda message: None) # type: ignore[arg-type] await middleware(request_scope, lambda: None, lambda message: None) # type: ignore[arg-type]
assert called is True assert called is True
@@ -40,11 +40,11 @@ async def test_request_id_middleware_ignores_blank_client_header_and_generates_o
middleware = RequestIdMiddleware(app) middleware = RequestIdMiddleware(app)
scope = { request_scope = {
"type": "http", "type": "http",
"headers": [(REQUEST_ID_HEADER.lower().encode("latin-1"), b" ")], "headers": [(REQUEST_ID_HEADER.lower().encode("latin-1"), b" ")],
} }
await middleware(scope, lambda: None, send) await middleware(request_scope, lambda: None, send)
assert isinstance(captured_request_id, str) and captured_request_id assert isinstance(captured_request_id, str) and captured_request_id
# Header should reflect the generated id, not the blank one. # Header should reflect the generated id, not the blank one.
@@ -78,8 +78,8 @@ async def test_request_id_middleware_does_not_duplicate_existing_header() -> Non
middleware = RequestIdMiddleware(app) middleware = RequestIdMiddleware(app)
scope = {"type": "http", "headers": []} request_scope = {"type": "http", "headers": []}
await middleware(scope, lambda: None, send) await middleware(request_scope, lambda: None, send)
assert sent_start is True assert sent_start is True
assert start_headers is not None assert start_headers is not None

View File

@@ -20,7 +20,7 @@ from app.services import task_dependencies as td
async def _make_engine() -> AsyncEngine: async def _make_engine() -> AsyncEngine:
# Single shared in-memory db per engine. # Single shared in-memory db per engine.
engine = create_async_engine("sqlite+aiosqlite:///:memory:") engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async with engine.begin() as conn: async with engine.connect() as conn, conn.begin():
await conn.run_sync(SQLModel.metadata.create_all) await conn.run_sync(SQLModel.metadata.create_all)
return engine return engine