refactor: replace DefaultLimitOffsetPage with LimitOffsetPage in multiple files and update timezone handling to use UTC
This commit is contained in:
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime, timezone
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -36,6 +36,7 @@ from app.services.organizations import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator, Sequence
|
from collections.abc import AsyncIterator, Sequence
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
router = APIRouter(prefix="/activity", tags=["activity"])
|
router = APIRouter(prefix="/activity", tags=["activity"])
|
||||||
@@ -63,7 +64,7 @@ def _parse_since(value: str | None) -> datetime | None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
if parsed.tzinfo is not None:
|
if parsed.tzinfo is not None:
|
||||||
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
@@ -145,7 +146,7 @@ async def _fetch_task_comment_events(
|
|||||||
async def list_activity(
|
async def list_activity(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
actor: ActorContext = ACTOR_DEP,
|
actor: ActorContext = ACTOR_DEP,
|
||||||
) -> DefaultLimitOffsetPage[ActivityEventRead]:
|
) -> LimitOffsetPage[ActivityEventRead]:
|
||||||
"""List activity events visible to the calling actor."""
|
"""List activity events visible to the calling actor."""
|
||||||
statement = select(ActivityEvent)
|
statement = select(ActivityEvent)
|
||||||
if actor.actor_type == "agent" and actor.agent:
|
if actor.actor_type == "agent" and actor.agent:
|
||||||
@@ -174,7 +175,7 @@ async def list_task_comment_feed(
|
|||||||
board_id: UUID | None = BOARD_ID_QUERY,
|
board_id: UUID | None = BOARD_ID_QUERY,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||||
) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]:
|
) -> LimitOffsetPage[ActivityTaskCommentFeedItemRead]:
|
||||||
"""List task-comment feed items for accessible boards."""
|
"""List task-comment feed items for accessible boards."""
|
||||||
statement = (
|
statement = (
|
||||||
select(ActivityEvent, Task, Board, Agent)
|
select(ActivityEvent, Task, Board, Agent)
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ if TYPE_CHECKING:
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.models.activity_events import ActivityEvent
|
from app.models.activity_events import ActivityEvent
|
||||||
@@ -222,7 +223,7 @@ async def _require_gateway_board(
|
|||||||
async def list_boards(
|
async def list_boards(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> DefaultLimitOffsetPage[BoardRead]:
|
) -> LimitOffsetPage[BoardRead]:
|
||||||
"""List boards visible to the authenticated agent."""
|
"""List boards visible to the authenticated agent."""
|
||||||
statement = select(Board)
|
statement = select(Board)
|
||||||
if agent_ctx.agent.board_id:
|
if agent_ctx.agent.board_id:
|
||||||
@@ -246,7 +247,7 @@ async def list_agents(
|
|||||||
board_id: UUID | None = BOARD_ID_QUERY,
|
board_id: UUID | None = BOARD_ID_QUERY,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> DefaultLimitOffsetPage[AgentRead]:
|
) -> LimitOffsetPage[AgentRead]:
|
||||||
"""List agents, optionally filtered to a board."""
|
"""List agents, optionally filtered to a board."""
|
||||||
statement = select(Agent)
|
statement = select(Agent)
|
||||||
if agent_ctx.agent.board_id:
|
if agent_ctx.agent.board_id:
|
||||||
@@ -277,7 +278,7 @@ async def list_tasks(
|
|||||||
board: Board = BOARD_DEP,
|
board: Board = BOARD_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> DefaultLimitOffsetPage[TaskRead]:
|
) -> LimitOffsetPage[TaskRead]:
|
||||||
"""List tasks on a board with optional status and assignment filters."""
|
"""List tasks on a board with optional status and assignment filters."""
|
||||||
_guard_board_access(agent_ctx, board)
|
_guard_board_access(agent_ctx, board)
|
||||||
return await tasks_api.list_tasks(
|
return await tasks_api.list_tasks(
|
||||||
@@ -414,7 +415,7 @@ async def list_task_comments(
|
|||||||
task: Task = TASK_DEP,
|
task: Task = TASK_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> DefaultLimitOffsetPage[TaskCommentRead]:
|
) -> LimitOffsetPage[TaskCommentRead]:
|
||||||
"""List comments for a task visible to the authenticated agent."""
|
"""List comments for a task visible to the authenticated agent."""
|
||||||
if (
|
if (
|
||||||
agent_ctx.agent.board_id
|
agent_ctx.agent.board_id
|
||||||
@@ -460,7 +461,7 @@ async def list_board_memory(
|
|||||||
board: Board = BOARD_DEP,
|
board: Board = BOARD_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> DefaultLimitOffsetPage[BoardMemoryRead]:
|
) -> LimitOffsetPage[BoardMemoryRead]:
|
||||||
"""List board memory entries with optional chat filtering."""
|
"""List board memory entries with optional chat filtering."""
|
||||||
_guard_board_access(agent_ctx, board)
|
_guard_board_access(agent_ctx, board)
|
||||||
return await board_memory_api.list_board_memory(
|
return await board_memory_api.list_board_memory(
|
||||||
@@ -497,7 +498,7 @@ async def list_approvals(
|
|||||||
board: Board = BOARD_DEP,
|
board: Board = BOARD_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> DefaultLimitOffsetPage[ApprovalRead]:
|
) -> LimitOffsetPage[ApprovalRead]:
|
||||||
"""List approvals for a board."""
|
"""List approvals for a board."""
|
||||||
_guard_board_access(agent_ctx, board)
|
_guard_board_access(agent_ctx, board)
|
||||||
return await approvals_api.list_approvals(
|
return await approvals_api.list_approvals(
|
||||||
@@ -960,12 +961,12 @@ async def broadcast_gateway_lead_message(
|
|||||||
sent = 0
|
sent = 0
|
||||||
failed = 0
|
failed = 0
|
||||||
|
|
||||||
async def _send_to_board(board: Board) -> GatewayLeadBroadcastBoardResult:
|
async def _send_to_board(target_board: Board) -> GatewayLeadBroadcastBoardResult:
|
||||||
try:
|
try:
|
||||||
lead, _lead_created = await ensure_board_lead_agent(
|
lead, _lead_created = await ensure_board_lead_agent(
|
||||||
session,
|
session,
|
||||||
request=LeadAgentRequest(
|
request=LeadAgentRequest(
|
||||||
board=board,
|
board=target_board,
|
||||||
gateway=gateway,
|
gateway=gateway,
|
||||||
config=config,
|
config=config,
|
||||||
user=None,
|
user=None,
|
||||||
@@ -975,14 +976,14 @@ async def broadcast_gateway_lead_message(
|
|||||||
lead_session_key = _require_lead_session_key(lead)
|
lead_session_key = _require_lead_session_key(lead)
|
||||||
message = (
|
message = (
|
||||||
f"{header}\n"
|
f"{header}\n"
|
||||||
f"Board: {board.name}\n"
|
f"Board: {target_board.name}\n"
|
||||||
f"Board ID: {board.id}\n"
|
f"Board ID: {target_board.id}\n"
|
||||||
f"From agent: {agent_ctx.agent.name}\n"
|
f"From agent: {agent_ctx.agent.name}\n"
|
||||||
f"{correlation_line}\n"
|
f"{correlation_line}\n"
|
||||||
f"{payload.content.strip()}\n\n"
|
f"{payload.content.strip()}\n\n"
|
||||||
"Reply to the gateway main by writing a NON-chat memory item "
|
"Reply to the gateway main by writing a NON-chat memory item "
|
||||||
"on this board:\n"
|
"on this board:\n"
|
||||||
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
|
f"POST {base_url}/api/v1/agent/boards/{target_board.id}/memory\n"
|
||||||
f'Body: {{"content":"...","tags":{tags_json},'
|
f'Body: {{"content":"...","tags":{tags_json},'
|
||||||
f'"source":"{reply_source}"}}\n'
|
f'"source":"{reply_source}"}}\n'
|
||||||
"Do NOT reply in OpenClaw chat."
|
"Do NOT reply in OpenClaw chat."
|
||||||
@@ -990,14 +991,14 @@ async def broadcast_gateway_lead_message(
|
|||||||
await ensure_session(lead_session_key, config=config, label=lead.name)
|
await ensure_session(lead_session_key, config=config, label=lead.name)
|
||||||
await send_message(message, session_key=lead_session_key, config=config)
|
await send_message(message, session_key=lead_session_key, config=config)
|
||||||
return GatewayLeadBroadcastBoardResult(
|
return GatewayLeadBroadcastBoardResult(
|
||||||
board_id=board.id,
|
board_id=target_board.id,
|
||||||
lead_agent_id=lead.id,
|
lead_agent_id=lead.id,
|
||||||
lead_agent_name=lead.name,
|
lead_agent_name=lead.name,
|
||||||
ok=True,
|
ok=True,
|
||||||
)
|
)
|
||||||
except (HTTPException, OpenClawGatewayError, ValueError) as exc:
|
except (HTTPException, OpenClawGatewayError, ValueError) as exc:
|
||||||
return GatewayLeadBroadcastBoardResult(
|
return GatewayLeadBroadcastBoardResult(
|
||||||
board_id=board.id,
|
board_id=target_board.id,
|
||||||
ok=False,
|
ok=False,
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import UTC, datetime, timedelta
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
@@ -65,6 +65,7 @@ from app.services.organizations import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator, Sequence
|
from collections.abc import AsyncIterator, Sequence
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlalchemy.sql.elements import ColumnElement
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
@@ -115,7 +116,7 @@ def _parse_since(value: str | None) -> datetime | None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
if parsed.tzinfo is not None:
|
if parsed.tzinfo is not None:
|
||||||
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
@@ -564,7 +565,7 @@ async def _validate_agent_update_inputs(
|
|||||||
updates: dict[str, Any],
|
updates: dict[str, Any],
|
||||||
make_main: bool | None,
|
make_main: bool | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if make_main is True and not is_org_admin(ctx.member):
|
if make_main and not is_org_admin(ctx.member):
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
if "status" in updates:
|
if "status" in updates:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -597,7 +598,7 @@ async def _apply_agent_update_mutations(
|
|||||||
)
|
)
|
||||||
gateway_for_main: Gateway | None = None
|
gateway_for_main: Gateway | None = None
|
||||||
|
|
||||||
if make_main is True:
|
if make_main:
|
||||||
board_source = updates.get("board_id") or agent.board_id
|
board_source = updates.get("board_id") or agent.board_id
|
||||||
board_for_main = await _require_board(session, board_source)
|
board_for_main = await _require_board(session, board_source)
|
||||||
gateway_for_main, _ = await _require_gateway(session, board_for_main)
|
gateway_for_main, _ = await _require_gateway(session, board_for_main)
|
||||||
@@ -605,10 +606,10 @@ async def _apply_agent_update_mutations(
|
|||||||
agent.is_board_lead = False
|
agent.is_board_lead = False
|
||||||
agent.openclaw_session_id = gateway_for_main.main_session_key
|
agent.openclaw_session_id = gateway_for_main.main_session_key
|
||||||
main_gateway = gateway_for_main
|
main_gateway = gateway_for_main
|
||||||
elif make_main is False:
|
elif make_main is not None:
|
||||||
agent.openclaw_session_id = None
|
agent.openclaw_session_id = None
|
||||||
|
|
||||||
if make_main is not True and "board_id" in updates:
|
if not make_main and "board_id" in updates:
|
||||||
await _require_board(session, updates["board_id"])
|
await _require_board(session, updates["board_id"])
|
||||||
for key, value in updates.items():
|
for key, value in updates.items():
|
||||||
setattr(agent, key, value)
|
setattr(agent, key, value)
|
||||||
@@ -633,7 +634,7 @@ async def _resolve_agent_update_target(
|
|||||||
main_gateway: Gateway | None,
|
main_gateway: Gateway | None,
|
||||||
gateway_for_main: Gateway | None,
|
gateway_for_main: Gateway | None,
|
||||||
) -> _AgentUpdateProvisionTarget:
|
) -> _AgentUpdateProvisionTarget:
|
||||||
if make_main is True:
|
if make_main:
|
||||||
if gateway_for_main is None:
|
if gateway_for_main is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
@@ -955,7 +956,7 @@ async def list_agents(
|
|||||||
gateway_id: UUID | None = GATEWAY_ID_QUERY,
|
gateway_id: UUID | None = GATEWAY_ID_QUERY,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||||
) -> DefaultLimitOffsetPage[AgentRead]:
|
) -> LimitOffsetPage[AgentRead]:
|
||||||
"""List agents visible to the active organization admin."""
|
"""List agents visible to the active organization admin."""
|
||||||
main_session_keys = await _get_gateway_main_session_keys(session)
|
main_session_keys = await _get_gateway_main_session_keys(session)
|
||||||
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
|
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timezone
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -35,6 +35,7 @@ from app.schemas.pagination import DefaultLimitOffsetPage
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
@@ -79,7 +80,7 @@ def _parse_since(value: str | None) -> datetime | None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
if parsed.tzinfo is not None:
|
if parsed.tzinfo is not None:
|
||||||
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
@@ -118,7 +119,7 @@ async def list_approvals(
|
|||||||
board: Board = BOARD_READ_DEP,
|
board: Board = BOARD_READ_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
_actor: ActorContext = ACTOR_DEP,
|
_actor: ActorContext = ACTOR_DEP,
|
||||||
) -> DefaultLimitOffsetPage[ApprovalRead]:
|
) -> LimitOffsetPage[ApprovalRead]:
|
||||||
"""List approvals for a board, optionally filtering by status."""
|
"""List approvals for a board, optionally filtering by status."""
|
||||||
statement = Approval.objects.filter_by(board_id=board.id)
|
statement = Approval.objects.filter_by(board_id=board.id)
|
||||||
if status_filter:
|
if status_filter:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -53,6 +53,7 @@ from app.services.organizations import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.services.organizations import OrganizationContext
|
from app.services.organizations import OrganizationContext
|
||||||
@@ -90,7 +91,7 @@ def _parse_since(value: str | None) -> datetime | None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
if parsed.tzinfo is not None:
|
if parsed.tzinfo is not None:
|
||||||
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
@@ -343,7 +344,7 @@ async def list_board_group_memory(
|
|||||||
is_chat: bool | None = IS_CHAT_QUERY,
|
is_chat: bool | None = IS_CHAT_QUERY,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||||
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
|
) -> LimitOffsetPage[BoardGroupMemoryRead]:
|
||||||
"""List board-group memory entries for a specific group."""
|
"""List board-group memory entries for a specific group."""
|
||||||
await _require_group_access(session, group_id=group_id, ctx=ctx, write=False)
|
await _require_group_access(session, group_id=group_id, ctx=ctx, write=False)
|
||||||
statement = (
|
statement = (
|
||||||
@@ -439,7 +440,7 @@ async def list_board_group_memory_for_board(
|
|||||||
is_chat: bool | None = IS_CHAT_QUERY,
|
is_chat: bool | None = IS_CHAT_QUERY,
|
||||||
board: Board = BOARD_READ_DEP,
|
board: Board = BOARD_READ_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
|
) -> LimitOffsetPage[BoardGroupMemoryRead]:
|
||||||
"""List memory entries for the board's linked group."""
|
"""List memory entries for the board's linked group."""
|
||||||
group_id = board.board_group_id
|
group_id = board.board_group_id
|
||||||
if group_id is None:
|
if group_id is None:
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from app.services.organizations import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.models.organization_members import OrganizationMember
|
from app.models.organization_members import OrganizationMember
|
||||||
@@ -103,7 +104,7 @@ async def _require_group_access(
|
|||||||
async def list_board_groups(
|
async def list_board_groups(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||||
) -> DefaultLimitOffsetPage[BoardGroupRead]:
|
) -> LimitOffsetPage[BoardGroupRead]:
|
||||||
"""List board groups in the active organization."""
|
"""List board groups in the active organization."""
|
||||||
if member_all_boards_read(ctx.member):
|
if member_all_boards_read(ctx.member):
|
||||||
statement = select(BoardGroup).where(
|
statement = select(BoardGroup).where(
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timezone
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -39,6 +39,7 @@ from app.services.mentions import extract_mentions, matches_agent_mention
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
@@ -67,7 +68,7 @@ def _parse_since(value: str | None) -> datetime | None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
if parsed.tzinfo is not None:
|
if parsed.tzinfo is not None:
|
||||||
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
@@ -250,7 +251,7 @@ async def list_board_memory(
|
|||||||
board: Board = BOARD_READ_DEP,
|
board: Board = BOARD_READ_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
_actor: ActorContext = ACTOR_DEP,
|
_actor: ActorContext = ACTOR_DEP,
|
||||||
) -> DefaultLimitOffsetPage[BoardMemoryRead]:
|
) -> LimitOffsetPage[BoardMemoryRead]:
|
||||||
"""List board memory entries, optionally filtering chat entries."""
|
"""List board memory entries, optionally filtering chat entries."""
|
||||||
statement = (
|
statement = (
|
||||||
BoardMemory.objects.filter_by(board_id=board.id)
|
BoardMemory.objects.filter_by(board_id=board.id)
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from app.services.board_snapshot import build_board_snapshot
|
|||||||
from app.services.organizations import OrganizationContext, board_access_filter
|
from app.services.organizations import OrganizationContext, board_access_filter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
router = APIRouter(prefix="/boards", tags=["boards"])
|
router = APIRouter(prefix="/boards", tags=["boards"])
|
||||||
@@ -246,7 +247,7 @@ async def list_boards(
|
|||||||
board_group_id: UUID | None = BOARD_GROUP_ID_QUERY,
|
board_group_id: UUID | None = BOARD_GROUP_ID_QUERY,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||||
) -> DefaultLimitOffsetPage[BoardRead]:
|
) -> LimitOffsetPage[BoardRead]:
|
||||||
"""List boards visible to the current organization member."""
|
"""List boards visible to the current organization member."""
|
||||||
statement = select(Board).where(board_access_filter(ctx.member, write=False))
|
statement = select(Board).where(board_access_filter(ctx.member, write=False))
|
||||||
if gateway_id is not None:
|
if gateway_id is not None:
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from app.services.template_sync import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.services.organizations import OrganizationContext
|
from app.services.organizations import OrganizationContext
|
||||||
@@ -224,7 +225,7 @@ async def _ensure_main_agent(
|
|||||||
async def list_gateways(
|
async def list_gateways(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||||
) -> DefaultLimitOffsetPage[GatewayRead]:
|
) -> LimitOffsetPage[GatewayRead]:
|
||||||
"""List gateways for the caller's organization."""
|
"""List gateways for the caller's organization."""
|
||||||
statement = (
|
statement = (
|
||||||
Gateway.objects.filter_by(organization_id=ctx.organization.id)
|
Gateway.objects.filter_by(organization_id=ctx.organization.id)
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from typing import Literal
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy import DateTime, case, cast, func
|
from sqlalchemy import DateTime, case, func
|
||||||
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -152,7 +153,7 @@ async def _query_cycle_time(
|
|||||||
board_ids: list[UUID],
|
board_ids: list[UUID],
|
||||||
) -> DashboardRangeSeries:
|
) -> DashboardRangeSeries:
|
||||||
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
|
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
|
||||||
in_progress = cast(Task.in_progress_at, DateTime)
|
in_progress = sql_cast(Task.in_progress_at, DateTime)
|
||||||
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
|
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
|
||||||
statement = (
|
statement = (
|
||||||
select(bucket_col, func.avg(duration_hours))
|
select(bucket_col, func.avg(duration_hours))
|
||||||
@@ -249,7 +250,7 @@ async def _median_cycle_time_7d(
|
|||||||
) -> float | None:
|
) -> float | None:
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
start = now - timedelta(days=7)
|
start = now - timedelta(days=7)
|
||||||
in_progress = cast(Task.in_progress_at, DateTime)
|
in_progress = sql_cast(Task.in_progress_at, DateTime)
|
||||||
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
|
duration_hours = func.extract("epoch", Task.updated_at - in_progress) / 3600.0
|
||||||
statement = (
|
statement = (
|
||||||
select(func.percentile_cont(0.5).within_group(duration_hours))
|
select(func.percentile_cont(0.5).within_group(duration_hours))
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
from typing import TYPE_CHECKING, Any, Sequence
|
from typing import TYPE_CHECKING, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
@@ -65,6 +65,9 @@ from app.services.organizations import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import AuthContext
|
from app.core.auth import AuthContext
|
||||||
@@ -369,7 +372,7 @@ async def get_my_membership(
|
|||||||
async def list_org_members(
|
async def list_org_members(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||||
) -> DefaultLimitOffsetPage[OrganizationMemberRead]:
|
) -> LimitOffsetPage[OrganizationMemberRead]:
|
||||||
"""List members for the active organization."""
|
"""List members for the active organization."""
|
||||||
statement = (
|
statement = (
|
||||||
select(OrganizationMember, User)
|
select(OrganizationMember, User)
|
||||||
@@ -542,7 +545,7 @@ async def remove_org_member(
|
|||||||
async def list_org_invites(
|
async def list_org_invites(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||||
) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
|
) -> LimitOffsetPage[OrganizationInviteRead]:
|
||||||
"""List pending invites for the active organization."""
|
"""List pending invites for the active organization."""
|
||||||
statement = (
|
statement = (
|
||||||
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
|
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
|
||||||
|
|||||||
@@ -3,13 +3,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||||
|
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
from app.db.queryset import QuerySet, qs
|
from app.db.queryset import QuerySet, qs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.orm import Mapped
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
|
|
||||||
@@ -27,11 +29,17 @@ class APIQuerySet(Generic[ModelT]):
|
|||||||
"""Expose the underlying SQL statement for advanced composition."""
|
"""Expose the underlying SQL statement for advanced composition."""
|
||||||
return self.queryset.statement
|
return self.queryset.statement
|
||||||
|
|
||||||
def filter(self, *criteria: object) -> APIQuerySet[ModelT]:
|
def filter(
|
||||||
|
self,
|
||||||
|
*criteria: ColumnElement[bool] | bool,
|
||||||
|
) -> APIQuerySet[ModelT]:
|
||||||
"""Return a new queryset with additional SQL criteria applied."""
|
"""Return a new queryset with additional SQL criteria applied."""
|
||||||
return APIQuerySet(self.queryset.filter(*criteria))
|
return APIQuerySet(self.queryset.filter(*criteria))
|
||||||
|
|
||||||
def order_by(self, *ordering: object) -> APIQuerySet[ModelT]:
|
def order_by(
|
||||||
|
self,
|
||||||
|
*ordering: Mapped[Any] | ColumnElement[Any] | str,
|
||||||
|
) -> APIQuerySet[ModelT]:
|
||||||
"""Return a new queryset with ordering clauses applied."""
|
"""Return a new queryset with ordering clauses applied."""
|
||||||
return APIQuerySet(self.queryset.order_by(*ordering))
|
return APIQuerySet(self.queryset.order_by(*ordering))
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import json
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, cast
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
@@ -67,8 +67,9 @@ from app.services.task_dependencies import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator, Sequence
|
from collections.abc import AsyncIterator, Sequence
|
||||||
|
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
|
|
||||||
from app.core.auth import AuthContext
|
from app.core.auth import AuthContext
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
@@ -85,6 +86,7 @@ TASK_EVENT_TYPES = {
|
|||||||
SSE_SEEN_MAX = 2000
|
SSE_SEEN_MAX = 2000
|
||||||
TASK_SNIPPET_MAX_LEN = 500
|
TASK_SNIPPET_MAX_LEN = 500
|
||||||
TASK_SNIPPET_TRUNCATED_LEN = 497
|
TASK_SNIPPET_TRUNCATED_LEN = 497
|
||||||
|
TASK_EVENT_ROW_LEN = 2
|
||||||
BOARD_READ_DEP = Depends(get_board_for_actor_read)
|
BOARD_READ_DEP = Depends(get_board_for_actor_read)
|
||||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||||
SINCE_QUERY = Query(default=None)
|
SINCE_QUERY = Query(default=None)
|
||||||
@@ -154,7 +156,7 @@ def _parse_since(value: str | None) -> datetime | None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
if parsed.tzinfo is not None:
|
if parsed.tzinfo is not None:
|
||||||
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
@@ -168,6 +170,24 @@ def _coerce_task_items(items: Sequence[object]) -> list[Task]:
|
|||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_task_event_rows(
|
||||||
|
items: Sequence[object],
|
||||||
|
) -> list[tuple[ActivityEvent, Task | None]]:
|
||||||
|
rows: list[tuple[ActivityEvent, Task | None]] = []
|
||||||
|
for item in items:
|
||||||
|
if (
|
||||||
|
isinstance(item, tuple)
|
||||||
|
and len(item) == TASK_EVENT_ROW_LEN
|
||||||
|
and isinstance(item[0], ActivityEvent)
|
||||||
|
and (isinstance(item[1], Task) or item[1] is None)
|
||||||
|
):
|
||||||
|
rows.append((item[0], item[1]))
|
||||||
|
continue
|
||||||
|
msg = "Expected (ActivityEvent, Task | None) rows"
|
||||||
|
raise TypeError(msg)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
async def _lead_was_mentioned(
|
async def _lead_was_mentioned(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
task: Task,
|
task: Task,
|
||||||
@@ -276,16 +296,16 @@ async def _fetch_task_events(
|
|||||||
)
|
)
|
||||||
if not task_ids:
|
if not task_ids:
|
||||||
return []
|
return []
|
||||||
statement = cast(
|
statement = (
|
||||||
"Select[tuple[ActivityEvent, Task | None]]",
|
|
||||||
select(ActivityEvent, Task)
|
select(ActivityEvent, Task)
|
||||||
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
|
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
|
||||||
.where(col(ActivityEvent.task_id).in_(task_ids))
|
.where(col(ActivityEvent.task_id).in_(task_ids))
|
||||||
.where(col(ActivityEvent.event_type).in_(TASK_EVENT_TYPES))
|
.where(col(ActivityEvent.event_type).in_(TASK_EVENT_TYPES))
|
||||||
.where(col(ActivityEvent.created_at) >= since)
|
.where(col(ActivityEvent.created_at) >= since)
|
||||||
.order_by(asc(col(ActivityEvent.created_at))),
|
.order_by(asc(col(ActivityEvent.created_at)))
|
||||||
)
|
)
|
||||||
return list(await session.exec(statement))
|
result = await session.execute(statement)
|
||||||
|
return _coerce_task_event_rows(list(result.tuples().all()))
|
||||||
|
|
||||||
|
|
||||||
def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
|
def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
|
||||||
@@ -718,7 +738,7 @@ async def list_tasks(
|
|||||||
board: Board = BOARD_READ_DEP,
|
board: Board = BOARD_READ_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
_actor: ActorContext = ACTOR_DEP,
|
_actor: ActorContext = ACTOR_DEP,
|
||||||
) -> DefaultLimitOffsetPage[TaskRead]:
|
) -> LimitOffsetPage[TaskRead]:
|
||||||
"""List board tasks with optional status and assignment filters."""
|
"""List board tasks with optional status and assignment filters."""
|
||||||
statement = _task_list_statement(
|
statement = _task_list_statement(
|
||||||
board_id=board.id,
|
board_id=board.id,
|
||||||
@@ -914,7 +934,7 @@ async def delete_task(
|
|||||||
async def list_task_comments(
|
async def list_task_comments(
|
||||||
task: Task = TASK_DEP,
|
task: Task = TASK_DEP,
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
) -> DefaultLimitOffsetPage[TaskCommentRead]:
|
) -> LimitOffsetPage[TaskCommentRead]:
|
||||||
"""List comments for a task in chronological order."""
|
"""List comments for a task in chronological order."""
|
||||||
statement = (
|
statement = (
|
||||||
select(ActivityEvent)
|
select(ActivityEvent)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import TYPE_CHECKING, Any, Final, cast
|
from typing import TYPE_CHECKING, Any, Final
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
@@ -81,19 +81,49 @@ def install_error_handling(app: FastAPI) -> None:
|
|||||||
|
|
||||||
app.add_exception_handler(
|
app.add_exception_handler(
|
||||||
RequestValidationError,
|
RequestValidationError,
|
||||||
cast(ExceptionHandler, _request_validation_handler),
|
_request_validation_exception_handler,
|
||||||
)
|
)
|
||||||
app.add_exception_handler(
|
app.add_exception_handler(
|
||||||
ResponseValidationError,
|
ResponseValidationError,
|
||||||
cast(ExceptionHandler, _response_validation_handler),
|
_response_validation_exception_handler,
|
||||||
)
|
)
|
||||||
app.add_exception_handler(
|
app.add_exception_handler(
|
||||||
StarletteHTTPException,
|
StarletteHTTPException,
|
||||||
cast(ExceptionHandler, _http_exception_handler),
|
_http_exception_exception_handler,
|
||||||
)
|
)
|
||||||
app.add_exception_handler(Exception, _unhandled_exception_handler)
|
app.add_exception_handler(Exception, _unhandled_exception_handler)
|
||||||
|
|
||||||
|
|
||||||
|
async def _request_validation_exception_handler(
|
||||||
|
request: Request,
|
||||||
|
exc: Exception,
|
||||||
|
) -> Response:
|
||||||
|
if not isinstance(exc, RequestValidationError):
|
||||||
|
msg = "Expected RequestValidationError"
|
||||||
|
raise TypeError(msg)
|
||||||
|
return await _request_validation_handler(request, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _response_validation_exception_handler(
|
||||||
|
request: Request,
|
||||||
|
exc: Exception,
|
||||||
|
) -> Response:
|
||||||
|
if not isinstance(exc, ResponseValidationError):
|
||||||
|
msg = "Expected ResponseValidationError"
|
||||||
|
raise TypeError(msg)
|
||||||
|
return await _response_validation_handler(request, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _http_exception_exception_handler(
|
||||||
|
request: Request,
|
||||||
|
exc: Exception,
|
||||||
|
) -> Response:
|
||||||
|
if not isinstance(exc, StarletteHTTPException):
|
||||||
|
msg = "Expected StarletteHTTPException"
|
||||||
|
raise TypeError(msg)
|
||||||
|
return await _http_exception_handler(request, exc)
|
||||||
|
|
||||||
|
|
||||||
def _get_request_id(request: Request) -> str | None:
|
def _get_request_id(request: Request) -> str | None:
|
||||||
request_id = getattr(request.state, "request_id", None)
|
request_id = getattr(request.state, "request_id", None)
|
||||||
if isinstance(request_id, str) and request_id:
|
if isinstance(request_id, str) and request_id:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import UTC, datetime
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -128,7 +128,7 @@ class JsonFormatter(logging.Formatter):
|
|||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"timestamp": datetime.fromtimestamp(
|
"timestamp": datetime.fromtimestamp(
|
||||||
record.created,
|
record.created,
|
||||||
tz=timezone.utc,
|
tz=UTC,
|
||||||
).isoformat(),
|
).isoformat(),
|
||||||
"level": record.levelname,
|
"level": record.levelname,
|
||||||
"logger": record.name,
|
"logger": record.name,
|
||||||
@@ -153,6 +153,7 @@ class JsonFormatter(logging.Formatter):
|
|||||||
class KeyValueFormatter(logging.Formatter):
|
class KeyValueFormatter(logging.Formatter):
|
||||||
"""Formatter that appends extra fields as `key=value` pairs."""
|
"""Formatter that appends extra fields as `key=value` pairs."""
|
||||||
|
|
||||||
|
# noinspection PyMethodMayBeStatic
|
||||||
def format(self, record: logging.LogRecord) -> str:
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
"""Render a log line with appended non-standard record fields."""
|
"""Render a log line with appended non-standard record fields."""
|
||||||
base = super().format(record)
|
base = super().format(record)
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ async def save(
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
|
async def delete(session: AsyncSession, obj: SQLModel, *, commit: bool = True) -> None:
|
||||||
"""Delete an object with optional commit."""
|
"""Delete an object with optional commit."""
|
||||||
await session.delete(obj)
|
await session.delete(obj)
|
||||||
if commit:
|
if commit:
|
||||||
|
|||||||
@@ -3,19 +3,23 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable, Sequence
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, TypeVar
|
||||||
|
|
||||||
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
|
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
|
||||||
|
|
||||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]]
|
Transformer = Callable[
|
||||||
|
[Sequence[Any]],
|
||||||
|
Sequence[Any] | Awaitable[Sequence[Any]],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def paginate(
|
async def paginate(
|
||||||
@@ -23,12 +27,7 @@ async def paginate(
|
|||||||
statement: Select[Any] | SelectOfScalar[Any],
|
statement: Select[Any] | SelectOfScalar[Any],
|
||||||
*,
|
*,
|
||||||
transformer: Transformer | None = None,
|
transformer: Transformer | None = None,
|
||||||
) -> DefaultLimitOffsetPage[T]:
|
) -> LimitOffsetPage[T]:
|
||||||
"""Execute a paginated query and cast to the project page type alias."""
|
"""Execute a paginated query and cast to the project page type alias."""
|
||||||
# fastapi-pagination is not fully typed (it returns Any), but response_model
|
page = await _paginate(session, statement, transformer=transformer)
|
||||||
# validation ensures runtime correctness. Centralize casts here to keep strict
|
return DefaultLimitOffsetPage[T].model_validate(page)
|
||||||
# mypy clean.
|
|
||||||
return cast(
|
|
||||||
DefaultLimitOffsetPage[T],
|
|
||||||
await _paginate(session, statement, transformer=transformer),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ from app.db.queryset import QuerySet, qs
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
ModelT = TypeVar("ModelT", bound=SQLModel)
|
ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||||
|
|
||||||
|
|
||||||
@@ -31,11 +33,17 @@ class ModelManager(Generic[ModelT]):
|
|||||||
"""Return a queryset that yields no rows."""
|
"""Return a queryset that yields no rows."""
|
||||||
return qs(self.model).filter(false())
|
return qs(self.model).filter(false())
|
||||||
|
|
||||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
def filter(
|
||||||
|
self,
|
||||||
|
*criteria: ColumnElement[bool] | bool,
|
||||||
|
) -> QuerySet[ModelT]:
|
||||||
"""Return queryset filtered by SQL criteria expressions."""
|
"""Return queryset filtered by SQL criteria expressions."""
|
||||||
return self.all().filter(*criteria)
|
return self.all().filter(*criteria)
|
||||||
|
|
||||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
def where(
|
||||||
|
self,
|
||||||
|
*criteria: ColumnElement[bool] | bool,
|
||||||
|
) -> QuerySet[ModelT]:
|
||||||
"""Alias for `filter`."""
|
"""Alias for `filter`."""
|
||||||
return self.filter(*criteria)
|
return self.filter(*criteria)
|
||||||
|
|
||||||
@@ -76,6 +84,7 @@ class ModelManager(Generic[ModelT]):
|
|||||||
class ManagerDescriptor(Generic[ModelT]):
|
class ManagerDescriptor(Generic[ModelT]):
|
||||||
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
|
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
|
||||||
|
|
||||||
|
# noinspection PyMethodMayBeStatic
|
||||||
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
|
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
|
||||||
"""Return a fresh manager bound to the owning model class."""
|
"""Return a fresh manager bound to the owning model class."""
|
||||||
return ModelManager(owner)
|
return ModelManager(owner)
|
||||||
|
|||||||
@@ -3,11 +3,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.orm import Mapped
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
|
|
||||||
@@ -20,15 +22,18 @@ class QuerySet(Generic[ModelT]):
|
|||||||
|
|
||||||
statement: SelectOfScalar[ModelT]
|
statement: SelectOfScalar[ModelT]
|
||||||
|
|
||||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
def filter(
|
||||||
|
self,
|
||||||
|
*criteria: ColumnElement[bool] | bool,
|
||||||
|
) -> QuerySet[ModelT]:
|
||||||
"""Return a new queryset with additional SQL criteria."""
|
"""Return a new queryset with additional SQL criteria."""
|
||||||
statement = cast(
|
statement = self.statement.where(*criteria)
|
||||||
"SelectOfScalar[ModelT]",
|
|
||||||
cast(Any, self.statement).where(*criteria),
|
|
||||||
)
|
|
||||||
return replace(self, statement=statement)
|
return replace(self, statement=statement)
|
||||||
|
|
||||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
def where(
|
||||||
|
self,
|
||||||
|
*criteria: ColumnElement[bool] | bool,
|
||||||
|
) -> QuerySet[ModelT]:
|
||||||
"""Alias for `filter` to mirror SQLAlchemy naming."""
|
"""Alias for `filter` to mirror SQLAlchemy naming."""
|
||||||
return self.filter(*criteria)
|
return self.filter(*criteria)
|
||||||
|
|
||||||
@@ -37,12 +42,12 @@ class QuerySet(Generic[ModelT]):
|
|||||||
statement = self.statement.filter_by(**kwargs)
|
statement = self.statement.filter_by(**kwargs)
|
||||||
return replace(self, statement=statement)
|
return replace(self, statement=statement)
|
||||||
|
|
||||||
def order_by(self, *ordering: object) -> QuerySet[ModelT]:
|
def order_by(
|
||||||
|
self,
|
||||||
|
*ordering: Mapped[Any] | ColumnElement[Any] | str,
|
||||||
|
) -> QuerySet[ModelT]:
|
||||||
"""Return a new queryset with ordering clauses applied."""
|
"""Return a new queryset with ordering clauses applied."""
|
||||||
statement = cast(
|
statement = self.statement.order_by(*ordering)
|
||||||
"SelectOfScalar[ModelT]",
|
|
||||||
cast(Any, self.statement).order_by(*ordering),
|
|
||||||
)
|
|
||||||
return replace(self, statement=statement)
|
return replace(self, statement=statement)
|
||||||
|
|
||||||
def limit(self, value: int) -> QuerySet[ModelT]:
|
def limit(self, value: int) -> QuerySet[ModelT]:
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import anyio
|
|
||||||
from alembic import command
|
from alembic import command
|
||||||
from alembic.config import Config
|
from alembic.config import Config
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
@@ -65,11 +65,11 @@ async def init_db() -> None:
|
|||||||
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
|
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
|
||||||
if any(versions_dir.glob("*.py")):
|
if any(versions_dir.glob("*.py")):
|
||||||
logger.info("Running migrations on startup")
|
logger.info("Running migrations on startup")
|
||||||
await anyio.to_thread.run_sync(run_migrations)
|
await asyncio.to_thread(run_migrations)
|
||||||
return
|
return
|
||||||
logger.warning("No migration revisions found; falling back to create_all")
|
logger.warning("No migration revisions found; falling back to create_all")
|
||||||
|
|
||||||
async with async_engine.begin() as conn:
|
async with async_engine.connect() as conn, conn.begin():
|
||||||
await conn.run_sync(SQLModel.metadata.create_all)
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,8 @@ async def _await_response(
|
|||||||
data = json.loads(raw)
|
data = json.loads(raw)
|
||||||
|
|
||||||
if data.get("type") == "res" and data.get("id") == request_id:
|
if data.get("type") == "res" and data.get("id") == request_id:
|
||||||
if data.get("ok") is False:
|
ok = data.get("ok")
|
||||||
|
if ok is not None and not ok:
|
||||||
error = data.get("error", {}).get("message", "Gateway error")
|
error = data.get("error", {}).get("message", "Gateway error")
|
||||||
raise OpenClawGatewayError(error)
|
raise OpenClawGatewayError(error)
|
||||||
return data.get("payload")
|
return data.get("payload")
|
||||||
@@ -135,14 +136,14 @@ async def openclaw_call(
|
|||||||
first_message = None
|
first_message = None
|
||||||
try:
|
try:
|
||||||
first_message = await asyncio.wait_for(ws.recv(), timeout=2)
|
first_message = await asyncio.wait_for(ws.recv(), timeout=2)
|
||||||
except asyncio.TimeoutError:
|
except TimeoutError:
|
||||||
first_message = None
|
first_message = None
|
||||||
await _ensure_connected(ws, first_message, config)
|
await _ensure_connected(ws, first_message, config)
|
||||||
return await _send_request(ws, method, params)
|
return await _send_request(ws, method, params)
|
||||||
except OpenClawGatewayError:
|
except OpenClawGatewayError:
|
||||||
raise
|
raise
|
||||||
except (
|
except (
|
||||||
asyncio.TimeoutError,
|
TimeoutError,
|
||||||
ConnectionError,
|
ConnectionError,
|
||||||
OSError,
|
OSError,
|
||||||
ValueError,
|
ValueError,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class ActivityEvent(QueryModel, table=True):
|
class ActivityEvent(QueryModel, table=True):
|
||||||
"""Discrete activity event tied to tasks and agents."""
|
"""Discrete activity event tied to tasks and agents."""
|
||||||
|
|
||||||
__tablename__ = "activity_events"
|
__tablename__ = "activity_events" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
event_type: str = Field(index=True)
|
event_type: str = Field(index=True)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class Agent(QueryModel, table=True):
|
class Agent(QueryModel, table=True):
|
||||||
"""Agent configuration and lifecycle state persisted in the database."""
|
"""Agent configuration and lifecycle state persisted in the database."""
|
||||||
|
|
||||||
__tablename__ = "agents"
|
__tablename__ = "agents" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True)
|
board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class Approval(QueryModel, table=True):
|
class Approval(QueryModel, table=True):
|
||||||
"""Approval request and decision metadata for gated operations."""
|
"""Approval request and decision metadata for gated operations."""
|
||||||
|
|
||||||
__tablename__ = "approvals"
|
__tablename__ = "approvals" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class BoardGroupMemory(QueryModel, table=True):
|
class BoardGroupMemory(QueryModel, table=True):
|
||||||
"""Persisted memory items associated with a board group."""
|
"""Persisted memory items associated with a board group."""
|
||||||
|
|
||||||
__tablename__ = "board_group_memory"
|
__tablename__ = "board_group_memory" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
board_group_id: UUID = Field(foreign_key="board_groups.id", index=True)
|
board_group_id: UUID = Field(foreign_key="board_groups.id", index=True)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class BoardGroup(TenantScoped, table=True):
|
class BoardGroup(TenantScoped, table=True):
|
||||||
"""Logical grouping container for boards within an organization."""
|
"""Logical grouping container for boards within an organization."""
|
||||||
|
|
||||||
__tablename__ = "board_groups"
|
__tablename__ = "board_groups" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class BoardMemory(QueryModel, table=True):
|
class BoardMemory(QueryModel, table=True):
|
||||||
"""Persisted memory item attached directly to a board."""
|
"""Persisted memory item attached directly to a board."""
|
||||||
|
|
||||||
__tablename__ = "board_memory"
|
__tablename__ = "board_memory" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class BoardOnboardingSession(QueryModel, table=True):
|
class BoardOnboardingSession(QueryModel, table=True):
|
||||||
"""Persisted onboarding conversation and draft goal data for a board."""
|
"""Persisted onboarding conversation and draft goal data for a board."""
|
||||||
|
|
||||||
__tablename__ = "board_onboarding_sessions"
|
__tablename__ = "board_onboarding_sessions" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class Board(TenantScoped, table=True):
|
class Board(TenantScoped, table=True):
|
||||||
"""Primary board entity grouping tasks, agents, and goal metadata."""
|
"""Primary board entity grouping tasks, agents, and goal metadata."""
|
||||||
|
|
||||||
__tablename__ = "boards"
|
__tablename__ = "boards" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class Gateway(QueryModel, table=True):
|
class Gateway(QueryModel, table=True):
|
||||||
"""Configured external gateway endpoint and authentication settings."""
|
"""Configured external gateway endpoint and authentication settings."""
|
||||||
|
|
||||||
__tablename__ = "gateways"
|
__tablename__ = "gateways" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class OrganizationBoardAccess(QueryModel, table=True):
|
class OrganizationBoardAccess(QueryModel, table=True):
|
||||||
"""Member-specific board permissions within an organization."""
|
"""Member-specific board permissions within an organization."""
|
||||||
|
|
||||||
__tablename__ = "organization_board_access"
|
__tablename__ = "organization_board_access" # pyright: ignore[reportAssignmentType]
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"organization_member_id",
|
"organization_member_id",
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class OrganizationInviteBoardAccess(QueryModel, table=True):
|
class OrganizationInviteBoardAccess(QueryModel, table=True):
|
||||||
"""Invite-specific board permissions applied after invite acceptance."""
|
"""Invite-specific board permissions applied after invite acceptance."""
|
||||||
|
|
||||||
__tablename__ = "organization_invite_board_access"
|
__tablename__ = "organization_invite_board_access" # pyright: ignore[reportAssignmentType]
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"organization_invite_id",
|
"organization_invite_id",
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class OrganizationInvite(QueryModel, table=True):
|
class OrganizationInvite(QueryModel, table=True):
|
||||||
"""Invitation record granting prospective organization access."""
|
"""Invitation record granting prospective organization access."""
|
||||||
|
|
||||||
__tablename__ = "organization_invites"
|
__tablename__ = "organization_invites" # pyright: ignore[reportAssignmentType]
|
||||||
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)
|
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class OrganizationMember(QueryModel, table=True):
|
class OrganizationMember(QueryModel, table=True):
|
||||||
"""Membership row linking a user to an organization and permissions."""
|
"""Membership row linking a user to an organization and permissions."""
|
||||||
|
|
||||||
__tablename__ = "organization_members"
|
__tablename__ = "organization_members" # pyright: ignore[reportAssignmentType]
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"organization_id",
|
"organization_id",
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class Organization(QueryModel, table=True):
|
class Organization(QueryModel, table=True):
|
||||||
"""Top-level organization tenant record."""
|
"""Top-level organization tenant record."""
|
||||||
|
|
||||||
__tablename__ = "organizations"
|
__tablename__ = "organizations" # pyright: ignore[reportAssignmentType]
|
||||||
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)
|
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class TaskDependency(TenantScoped, table=True):
|
class TaskDependency(TenantScoped, table=True):
|
||||||
"""Directed dependency edge between two tasks in the same board."""
|
"""Directed dependency edge between two tasks in the same board."""
|
||||||
|
|
||||||
__tablename__ = "task_dependencies"
|
__tablename__ = "task_dependencies" # pyright: ignore[reportAssignmentType]
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"task_id",
|
"task_id",
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class TaskFingerprint(QueryModel, table=True):
|
class TaskFingerprint(QueryModel, table=True):
|
||||||
"""Hashed task-content fingerprint associated with a board and task."""
|
"""Hashed task-content fingerprint associated with a board and task."""
|
||||||
|
|
||||||
__tablename__ = "task_fingerprints"
|
__tablename__ = "task_fingerprints" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
board_id: UUID = Field(foreign_key="boards.id", index=True)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ RUNTIME_ANNOTATION_TYPES = (datetime,)
|
|||||||
class Task(TenantScoped, table=True):
|
class Task(TenantScoped, table=True):
|
||||||
"""Board-scoped task entity with ownership, status, and timing fields."""
|
"""Board-scoped task entity with ownership, status, and timing fields."""
|
||||||
|
|
||||||
__tablename__ = "tasks"
|
__tablename__ = "tasks" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True)
|
board_id: UUID | None = Field(default=None, foreign_key="boards.id", index=True)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.models.base import QueryModel
|
|||||||
class User(QueryModel, table=True):
|
class User(QueryModel, table=True):
|
||||||
"""Application user account and profile attributes."""
|
"""Application user account and profile attributes."""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
clerk_user_id: str = Field(index=True, unique=True)
|
clerk_user_id: str = Field(index=True, unique=True)
|
||||||
|
|||||||
@@ -34,11 +34,13 @@ class BoardBase(SQLModel):
|
|||||||
class BoardCreate(BoardBase):
|
class BoardCreate(BoardBase):
|
||||||
"""Payload for creating a board."""
|
"""Payload for creating a board."""
|
||||||
|
|
||||||
gateway_id: UUID
|
gateway_id: UUID | None = None
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_goal_fields(self) -> Self:
|
def validate_goal_fields(self) -> Self:
|
||||||
"""Require goal details when creating a confirmed goal board."""
|
"""Require gateway and goal details when creating a confirmed goal board."""
|
||||||
|
if self.gateway_id is None:
|
||||||
|
raise ValueError(_ERR_GATEWAY_REQUIRED)
|
||||||
if (
|
if (
|
||||||
self.board_type == "goal"
|
self.board_type == "goal"
|
||||||
and self.goal_confirmed
|
and self.goal_confirmed
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TypeVar
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
from fastapi_pagination.customization import CustomizedPage, UseParamsFields
|
from fastapi_pagination.customization import CustomizedPage, UseParamsFields
|
||||||
@@ -14,10 +14,15 @@ T = TypeVar("T")
|
|||||||
# Project-wide default pagination response model.
|
# Project-wide default pagination response model.
|
||||||
# - Keep `limit` / `offset` naming (matches existing API conventions).
|
# - Keep `limit` / `offset` naming (matches existing API conventions).
|
||||||
# - Cap list endpoints to 200 items per request (matches prior route-level constraints).
|
# - Cap list endpoints to 200 items per request (matches prior route-level constraints).
|
||||||
DefaultLimitOffsetPage = CustomizedPage[
|
if TYPE_CHECKING:
|
||||||
LimitOffsetPage[T],
|
# Type checkers treat this as a normal generic page type.
|
||||||
UseParamsFields(
|
DefaultLimitOffsetPage = LimitOffsetPage
|
||||||
limit=Query(200, ge=1, le=200),
|
else:
|
||||||
offset=Query(0, ge=0),
|
# Runtime uses project-default query param bounds for all list endpoints.
|
||||||
),
|
DefaultLimitOffsetPage = CustomizedPage[
|
||||||
]
|
LimitOffsetPage[T],
|
||||||
|
UseParamsFields(
|
||||||
|
limit=Query(200, ge=1, le=200),
|
||||||
|
offset=Query(0, ge=0),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|||||||
@@ -738,7 +738,7 @@ def _should_include_bootstrap(
|
|||||||
if not existing_files:
|
if not existing_files:
|
||||||
return False
|
return False
|
||||||
entry = existing_files.get("BOOTSTRAP.md")
|
entry = existing_files.get("BOOTSTRAP.md")
|
||||||
return not (entry and entry.get("missing") is True)
|
return not bool(entry and entry.get("missing"))
|
||||||
|
|
||||||
|
|
||||||
async def _set_agent_files(
|
async def _set_agent_files(
|
||||||
@@ -753,7 +753,7 @@ async def _set_agent_files(
|
|||||||
continue
|
continue
|
||||||
if name in PRESERVE_AGENT_EDITABLE_FILES:
|
if name in PRESERVE_AGENT_EDITABLE_FILES:
|
||||||
entry = existing_files.get(name)
|
entry = existing_files.get(name)
|
||||||
if entry and entry.get("missing") is not True:
|
if entry and not bool(entry.get("missing")):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
await openclaw_call(
|
await openclaw_call(
|
||||||
|
|||||||
@@ -117,20 +117,20 @@ def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
|
|||||||
visited: set[UUID] = set()
|
visited: set[UUID] = set()
|
||||||
in_stack: set[UUID] = set()
|
in_stack: set[UUID] = set()
|
||||||
|
|
||||||
def dfs(node: UUID) -> bool:
|
def dfs(current: UUID) -> bool:
|
||||||
if node in in_stack:
|
if current in in_stack:
|
||||||
return True
|
return True
|
||||||
if node in visited:
|
if current in visited:
|
||||||
return False
|
return False
|
||||||
visited.add(node)
|
visited.add(current)
|
||||||
in_stack.add(node)
|
in_stack.add(current)
|
||||||
for nxt in edges.get(node, set()):
|
for nxt in edges.get(current, set()):
|
||||||
if dfs(nxt):
|
if dfs(nxt):
|
||||||
return True
|
return True
|
||||||
in_stack.remove(node)
|
in_stack.remove(current)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return any(dfs(node) for node in nodes)
|
return any(dfs(start_node) for start_node in nodes)
|
||||||
|
|
||||||
|
|
||||||
async def validate_dependency_update(
|
async def validate_dependency_update(
|
||||||
|
|||||||
@@ -132,8 +132,8 @@ class _GatewayBackoff:
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self._delay_s = self._base_delay_s
|
self._delay_s = self._base_delay_s
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
async def _attempt(
|
async def _attempt(
|
||||||
self,
|
|
||||||
fn: Callable[[], Awaitable[T]],
|
fn: Callable[[], Awaitable[T]],
|
||||||
) -> tuple[T | None, OpenClawGatewayError | None]:
|
) -> tuple[T | None, OpenClawGatewayError | None]:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -4,16 +4,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from types import SimpleNamespace
|
from typing import Any
|
||||||
from typing import TYPE_CHECKING, cast
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.api import board_groups
|
from app.api import board_groups
|
||||||
|
from app.models.organization_members import OrganizationMember
|
||||||
if TYPE_CHECKING:
|
from app.models.organizations import Organization
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from app.services.organizations import OrganizationContext
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -47,12 +46,20 @@ async def test_delete_board_group_cleans_group_memory_first(
|
|||||||
_fake_require_group_access,
|
_fake_require_group_access,
|
||||||
)
|
)
|
||||||
|
|
||||||
session = _FakeSession()
|
session: Any = _FakeSession()
|
||||||
ctx = SimpleNamespace(member=object())
|
org_id = uuid4()
|
||||||
|
ctx = OrganizationContext(
|
||||||
|
organization=Organization(id=org_id, name=f"org-{org_id}"),
|
||||||
|
member=OrganizationMember(
|
||||||
|
organization_id=org_id,
|
||||||
|
user_id=uuid4(),
|
||||||
|
role="admin",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
await board_groups.delete_board_group(
|
await board_groups.delete_board_group(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
session=cast("AsyncSession", session),
|
session=session,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, cast
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,9 +12,6 @@ import pytest
|
|||||||
from app.api import boards
|
from app.api import boards
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
_NO_EXEC_RESULTS_ERROR = "No more exec_results left for session.exec"
|
_NO_EXEC_RESULTS_ERROR = "No more exec_results left for session.exec"
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +44,7 @@ class _FakeSession:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_board_cleans_org_board_access_rows() -> None:
|
async def test_delete_board_cleans_org_board_access_rows() -> None:
|
||||||
"""Deleting a board should clear org-board access rows before commit."""
|
"""Deleting a board should clear org-board access rows before commit."""
|
||||||
session = _FakeSession(exec_results=[[], []])
|
session: Any = _FakeSession(exec_results=[[], []])
|
||||||
board = Board(
|
board = Board(
|
||||||
id=uuid4(),
|
id=uuid4(),
|
||||||
organization_id=uuid4(),
|
organization_id=uuid4(),
|
||||||
@@ -57,7 +54,7 @@ async def test_delete_board_cleans_org_board_access_rows() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await boards.delete_board(
|
await boards.delete_board(
|
||||||
session=cast("AsyncSession", session),
|
session=session,
|
||||||
board=board,
|
board=board,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,8 @@ async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.Mo
|
|||||||
class _FakeDependencySession:
|
class _FakeDependencySession:
|
||||||
rollbacks: int = 0
|
rollbacks: int = 0
|
||||||
|
|
||||||
def in_transaction(self) -> bool:
|
@staticmethod
|
||||||
|
def in_transaction() -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def rollback(self) -> None:
|
async def rollback(self) -> None:
|
||||||
@@ -89,16 +90,19 @@ async def test_create_rolls_back_when_commit_fails() -> None:
|
|||||||
def add(self, value: Any) -> None:
|
def add(self, value: Any) -> None:
|
||||||
self.added.append(value)
|
self.added.append(value)
|
||||||
|
|
||||||
async def flush(self) -> None:
|
@staticmethod
|
||||||
|
async def flush() -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def commit(self) -> None:
|
@staticmethod
|
||||||
|
async def commit() -> None:
|
||||||
raise _CommitError("commit failed")
|
raise _CommitError("commit failed")
|
||||||
|
|
||||||
async def rollback(self) -> None:
|
async def rollback(self) -> None:
|
||||||
self.rollback_calls += 1
|
self.rollback_calls += 1
|
||||||
|
|
||||||
async def refresh(self, _value: Any) -> None:
|
@staticmethod
|
||||||
|
async def refresh(_value: Any) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
session = _FailCommitSession()
|
session = _FailCommitSession()
|
||||||
@@ -124,7 +128,8 @@ async def test_delete_where_rolls_back_when_commit_fails() -> None:
|
|||||||
self.exec_calls += 1
|
self.exec_calls += 1
|
||||||
return SimpleNamespace(rowcount=3)
|
return SimpleNamespace(rowcount=3)
|
||||||
|
|
||||||
async def commit(self) -> None:
|
@staticmethod
|
||||||
|
async def commit() -> None:
|
||||||
raise _CommitError("commit failed")
|
raise _CommitError("commit failed")
|
||||||
|
|
||||||
async def rollback(self) -> None:
|
async def rollback(self) -> None:
|
||||||
|
|||||||
@@ -4,17 +4,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from types import SimpleNamespace
|
from typing import Any
|
||||||
from typing import TYPE_CHECKING, cast
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
from app.api import organizations
|
from app.api import organizations
|
||||||
|
from app.models.organization_members import OrganizationMember
|
||||||
if TYPE_CHECKING:
|
from app.models.organizations import Organization
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from app.services.organizations import OrganizationContext
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -35,15 +34,19 @@ class _FakeSession:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_my_org_cleans_dependents_before_organization_delete() -> None:
|
async def test_delete_my_org_cleans_dependents_before_organization_delete() -> None:
|
||||||
"""Delete flow should remove dependent rows before the organization row."""
|
"""Delete flow should remove dependent rows before the organization row."""
|
||||||
session = _FakeSession()
|
session: Any = _FakeSession()
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
ctx = SimpleNamespace(
|
ctx = OrganizationContext(
|
||||||
organization=SimpleNamespace(id=org_id),
|
organization=Organization(id=org_id, name=f"org-{org_id}"),
|
||||||
member=SimpleNamespace(role="owner"),
|
member=OrganizationMember(
|
||||||
|
organization_id=org_id,
|
||||||
|
user_id=uuid4(),
|
||||||
|
role="owner",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await organizations.delete_my_org(
|
await organizations.delete_my_org(
|
||||||
session=cast("AsyncSession", session),
|
session=session,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -77,15 +80,20 @@ async def test_delete_my_org_cleans_dependents_before_organization_delete() -> N
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_my_org_requires_owner_role() -> None:
|
async def test_delete_my_org_requires_owner_role() -> None:
|
||||||
"""Delete flow should reject non-owner members with HTTP 403."""
|
"""Delete flow should reject non-owner members with HTTP 403."""
|
||||||
session = _FakeSession()
|
session: Any = _FakeSession()
|
||||||
ctx = SimpleNamespace(
|
org_id = uuid4()
|
||||||
organization=SimpleNamespace(id=uuid4()),
|
ctx = OrganizationContext(
|
||||||
member=SimpleNamespace(role="admin"),
|
organization=Organization(id=org_id, name=f"org-{org_id}"),
|
||||||
|
member=OrganizationMember(
|
||||||
|
organization_id=org_id,
|
||||||
|
user_id=uuid4(),
|
||||||
|
role="admin",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await organizations.delete_my_org(
|
await organizations.delete_my_org(
|
||||||
session=cast("AsyncSession", session),
|
session=session,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,16 +3,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
from app.api import organizations
|
from app.api import organizations
|
||||||
from app.models.organization_members import OrganizationMember
|
from app.models.organization_members import OrganizationMember
|
||||||
|
from app.models.organizations import Organization
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
|
from app.services.organizations import OrganizationContext
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -58,6 +59,17 @@ class _FakeSession:
|
|||||||
self.committed += 1
|
self.committed += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ctx(*, org_id: UUID, user_id: UUID, role: str) -> OrganizationContext:
|
||||||
|
return OrganizationContext(
|
||||||
|
organization=Organization(id=org_id, name=f"org-{org_id}"),
|
||||||
|
member=OrganizationMember(
|
||||||
|
organization_id=org_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=role,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_remove_org_member_deletes_member_access_and_member() -> None:
|
async def test_remove_org_member_deletes_member_access_and_member() -> None:
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
@@ -83,10 +95,7 @@ async def test_remove_org_member_deletes_member_access_and_member() -> None:
|
|||||||
_FakeExecResult(first_value=fallback_org_id),
|
_FakeExecResult(first_value=fallback_org_id),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
ctx = SimpleNamespace(
|
ctx = _make_ctx(org_id=org_id, user_id=actor_user_id, role="admin")
|
||||||
organization=SimpleNamespace(id=org_id),
|
|
||||||
member=SimpleNamespace(user_id=actor_user_id, role="admin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
await organizations.remove_org_member(member_id=member_id, session=session, ctx=ctx)
|
await organizations.remove_org_member(member_id=member_id, session=session, ctx=ctx)
|
||||||
|
|
||||||
@@ -109,10 +118,7 @@ async def test_remove_org_member_disallows_self_removal() -> None:
|
|||||||
role="member",
|
role="member",
|
||||||
)
|
)
|
||||||
session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
|
session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
|
||||||
ctx = SimpleNamespace(
|
ctx = _make_ctx(org_id=org_id, user_id=user_id, role="owner")
|
||||||
organization=SimpleNamespace(id=org_id),
|
|
||||||
member=SimpleNamespace(user_id=user_id, role="owner"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
|
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
|
||||||
@@ -133,10 +139,7 @@ async def test_remove_org_member_requires_owner_to_remove_owner() -> None:
|
|||||||
role="owner",
|
role="owner",
|
||||||
)
|
)
|
||||||
session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
|
session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
|
||||||
ctx = SimpleNamespace(
|
ctx = _make_ctx(org_id=org_id, user_id=uuid4(), role="admin")
|
||||||
organization=SimpleNamespace(id=org_id),
|
|
||||||
member=SimpleNamespace(user_id=uuid4(), role="admin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
|
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
|
||||||
@@ -162,10 +165,7 @@ async def test_remove_org_member_rejects_removing_last_owner() -> None:
|
|||||||
_FakeExecResult(all_values=[member]),
|
_FakeExecResult(all_values=[member]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
ctx = SimpleNamespace(
|
ctx = _make_ctx(org_id=org_id, user_id=uuid4(), role="owner")
|
||||||
organization=SimpleNamespace(id=org_id),
|
|
||||||
member=SimpleNamespace(user_id=uuid4(), role="owner"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
|
await organizations.remove_org_member(member_id=member.id, session=session, ctx=ctx)
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ async def test_request_id_middleware_passes_through_non_http_scope() -> None:
|
|||||||
|
|
||||||
middleware = RequestIdMiddleware(app)
|
middleware = RequestIdMiddleware(app)
|
||||||
|
|
||||||
scope = {"type": "websocket", "headers": []}
|
request_scope = {"type": "websocket", "headers": []}
|
||||||
await middleware(scope, lambda: None, lambda message: None) # type: ignore[arg-type]
|
await middleware(request_scope, lambda: None, lambda message: None) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert called is True
|
assert called is True
|
||||||
|
|
||||||
@@ -40,11 +40,11 @@ async def test_request_id_middleware_ignores_blank_client_header_and_generates_o
|
|||||||
|
|
||||||
middleware = RequestIdMiddleware(app)
|
middleware = RequestIdMiddleware(app)
|
||||||
|
|
||||||
scope = {
|
request_scope = {
|
||||||
"type": "http",
|
"type": "http",
|
||||||
"headers": [(REQUEST_ID_HEADER.lower().encode("latin-1"), b" ")],
|
"headers": [(REQUEST_ID_HEADER.lower().encode("latin-1"), b" ")],
|
||||||
}
|
}
|
||||||
await middleware(scope, lambda: None, send)
|
await middleware(request_scope, lambda: None, send)
|
||||||
|
|
||||||
assert isinstance(captured_request_id, str) and captured_request_id
|
assert isinstance(captured_request_id, str) and captured_request_id
|
||||||
# Header should reflect the generated id, not the blank one.
|
# Header should reflect the generated id, not the blank one.
|
||||||
@@ -78,8 +78,8 @@ async def test_request_id_middleware_does_not_duplicate_existing_header() -> Non
|
|||||||
|
|
||||||
middleware = RequestIdMiddleware(app)
|
middleware = RequestIdMiddleware(app)
|
||||||
|
|
||||||
scope = {"type": "http", "headers": []}
|
request_scope = {"type": "http", "headers": []}
|
||||||
await middleware(scope, lambda: None, send)
|
await middleware(request_scope, lambda: None, send)
|
||||||
|
|
||||||
assert sent_start is True
|
assert sent_start is True
|
||||||
assert start_headers is not None
|
assert start_headers is not None
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from app.services import task_dependencies as td
|
|||||||
async def _make_engine() -> AsyncEngine:
|
async def _make_engine() -> AsyncEngine:
|
||||||
# Single shared in-memory db per engine.
|
# Single shared in-memory db per engine.
|
||||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||||
async with engine.begin() as conn:
|
async with engine.connect() as conn, conn.begin():
|
||||||
await conn.run_sync(SQLModel.metadata.create_all)
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user