refactor: replace SQLModel with QueryModel in various models and update query methods

This commit is contained in:
Abhimanyu Saharan
2026-02-09 02:04:14 +05:30
parent e19e47106b
commit 228b99bc9b
40 changed files with 413 additions and 419 deletions

View File

@@ -102,7 +102,7 @@ def _guard_board_access(agent_ctx: AgentAuthContext, board: Board) -> None:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig:
if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -117,9 +117,7 @@ async def _require_gateway_main(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Agent missing session key"
)
gateway = (
await session.exec(select(Gateway).where(col(Gateway.main_session_key) == session_key))
).first()
gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(session)
if gateway is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -139,7 +137,7 @@ async def _require_gateway_board(
gateway: Gateway,
board_id: UUID | str,
) -> Board:
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if board.gateway_id != gateway.id:
@@ -254,7 +252,7 @@ async def create_task(
},
)
if task.assigned_agent_id:
agent = await session.get(Agent, task.assigned_agent_id)
agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.is_board_lead:
@@ -286,7 +284,7 @@ async def create_task(
)
await session.commit()
if task.assigned_agent_id:
assigned_agent = await session.get(Agent, task.assigned_agent_id)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent:
await tasks_api._notify_agent_on_task_assign(
session=session,
@@ -466,7 +464,7 @@ async def nudge_agent(
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
target = await session.get(Agent, agent_id)
target = await Agent.objects.by_id(agent_id).first(session)
if target is None or (target.board_id and target.board_id != board.id):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if not target.openclaw_session_id:
@@ -528,7 +526,7 @@ async def get_agent_soul(
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
target = await session.get(Agent, agent_id)
target = await Agent.objects.by_id(agent_id).first(session)
if target is None or (target.board_id and target.board_id != board.id):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
config = await _gateway_config(session, board)
@@ -566,7 +564,7 @@ async def update_agent_soul(
_guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
target = await session.get(Agent, agent_id)
target = await Agent.objects.by_id(agent_id).first(session)
if target is None or (target.board_id and target.board_id != board.id):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
config = await _gateway_config(session, board)
@@ -629,7 +627,7 @@ async def ask_user_via_gateway_main(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board is not attached to a gateway",
)
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -689,9 +687,7 @@ async def ask_user_via_gateway_main(
agent_id=agent_ctx.agent.id,
)
main_agent = (
await session.exec(select(Agent).where(col(Agent.openclaw_session_id) == main_session_key))
).first()
main_agent = await Agent.objects.filter_by(openclaw_session_id=main_session_key).first(session)
await session.commit()

View File

@@ -109,7 +109,7 @@ async def _require_board(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id is required",
)
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if user is not None:
@@ -125,7 +125,7 @@ async def _require_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board gateway_id is required",
)
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -182,9 +182,7 @@ async def _find_gateway_for_main_session(
) -> Gateway | None:
if not session_key:
return None
return (
await session.exec(select(Gateway).where(Gateway.main_session_key == session_key))
).first()
return await Gateway.objects.filter_by(main_session_key=session_key).first(session)
async def _ensure_gateway_session(
@@ -237,7 +235,7 @@ async def _require_user_context(session: AsyncSession, user: User | None) -> Org
member = await get_active_membership(session, user)
if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
organization = await session.get(Organization, member.organization_id)
organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return OrganizationContext(organization=organization, member=member)
@@ -258,7 +256,7 @@ async def _require_agent_access(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return
board = await session.get(Board, agent.board_id)
board = await Board.objects.by_id(agent.board_id).first(session)
if board is None or board.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if not await has_board_access(session, member=ctx.member, board=board, write=write):
@@ -323,7 +321,7 @@ async def list_agents(
if board_id is not None:
statement = statement.where(col(Agent.board_id) == board_id)
if gateway_id is not None:
gateway = await session.get(Gateway, gateway_id)
gateway = await Gateway.objects.by_id(gateway_id).first(session)
if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where(
@@ -532,7 +530,7 @@ async def get_agent(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> AgentRead:
agent = await session.get(Agent, agent_id)
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await _require_agent_access(session, agent=agent, ctx=ctx, write=False)
@@ -549,7 +547,7 @@ async def update_agent(
auth: AuthContext = Depends(get_auth_context),
ctx: OrganizationContext = Depends(require_org_admin),
) -> AgentRead:
agent = await session.get(Agent, agent_id)
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
@@ -728,7 +726,7 @@ async def heartbeat_agent(
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> AgentRead:
agent = await session.get(Agent, agent_id)
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id:
@@ -767,7 +765,7 @@ async def heartbeat_or_create_agent(
actor=actor,
)
statement = select(Agent).where(Agent.name == payload.name)
statement = Agent.objects.filter_by(name=payload.name).statement
if payload.board_id is not None:
statement = statement.where(Agent.board_id == payload.board_id)
agent = (await session.exec(statement)).first()
@@ -943,7 +941,7 @@ async def delete_agent(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse:
agent = await session.get(Agent, agent_id)
agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None:
return OkResponse()
await _require_agent_access(session, agent=agent, ctx=ctx, write=True)

View File

@@ -77,9 +77,8 @@ async def _fetch_approval_events(
since: datetime,
) -> list[Approval]:
statement = (
select(Approval)
.where(col(Approval.board_id) == board_id)
.where(
Approval.objects.filter_by(board_id=board_id)
.filter(
or_(
col(Approval.created_at) >= since,
col(Approval.resolved_at) >= since,
@@ -87,7 +86,7 @@ async def _fetch_approval_events(
)
.order_by(asc(col(Approval.created_at)))
)
return list(await session.exec(statement))
return await statement.all(session)
@router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead])
@@ -97,11 +96,11 @@ async def list_approvals(
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> DefaultLimitOffsetPage[ApprovalRead]:
statement = select(Approval).where(col(Approval.board_id) == board.id)
statement = Approval.objects.filter_by(board_id=board.id)
if status_filter:
statement = statement.where(col(Approval.status) == status_filter)
statement = statement.filter(col(Approval.status) == status_filter)
statement = statement.order_by(col(Approval.created_at).desc())
return await paginate(session, statement)
return await paginate(session, statement.statement)
@router.get("/stream")
@@ -207,7 +206,7 @@ async def update_approval(
board: Board = Depends(get_board_for_user_write),
session: AsyncSession = Depends(get_session),
) -> Approval:
approval = await session.get(Approval, approval_id)
approval = await Approval.objects.by_id(approval_id).first(session)
if approval is None or approval.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
updates = payload.model_dump(exclude_unset=True)

View File

@@ -8,7 +8,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import func
from sqlmodel import col, select
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
@@ -71,7 +71,7 @@ def _serialize_memory(memory: BoardGroupMemory) -> dict[str, object]:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if board.gateway_id is None:
return None
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url:
return None
return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -96,17 +96,17 @@ async def _fetch_memory_events(
is_chat: bool | None = None,
) -> list[BoardGroupMemory]:
statement = (
select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == board_group_id)
BoardGroupMemory.objects.filter_by(board_group_id=board_group_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
.filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
)
if is_chat is not None:
statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.where(col(BoardGroupMemory.created_at) >= since).order_by(
statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.filter(col(BoardGroupMemory.created_at) >= since).order_by(
col(BoardGroupMemory.created_at)
)
return list(await session.exec(statement))
return await statement.all(session)
async def _require_group_access(
@@ -116,7 +116,7 @@ async def _require_group_access(
ctx: OrganizationContext,
write: bool,
) -> BoardGroup:
group = await session.get(BoardGroup, group_id)
group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if group.organization_id != ctx.member.organization_id:
@@ -127,9 +127,9 @@ async def _require_group_access(
if not write and member_all_boards_read(ctx.member):
return group
board_ids = list(
await session.exec(select(Board.id).where(col(Board.board_group_id) == group_id))
)
board_ids = [
board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
]
if not board_ids:
if is_org_admin(ctx.member):
return group
@@ -156,12 +156,12 @@ async def _notify_group_memory_targets(
is_broadcast = "broadcast" in tags or "all" in mentions
# Fetch group boards + agents.
boards = list(await session.exec(select(Board).where(col(Board.board_group_id) == group.id)))
boards = await Board.objects.filter_by(board_group_id=group.id).all(session)
if not boards:
return
board_by_id = {board.id: board for board in boards}
board_ids = list(board_by_id.keys())
agents = list(await session.exec(select(Agent).where(col(Agent.board_id).in_(board_ids))))
agents = await Agent.objects.by_field_in("board_id", board_ids).all(session)
targets: dict[str, Agent] = {}
for agent in agents:
@@ -242,15 +242,15 @@ async def list_board_group_memory(
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
await _require_group_access(session, group_id=group_id, ctx=ctx, write=False)
statement = (
select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id)
BoardGroupMemory.objects.filter_by(board_group_id=group_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
.filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
)
if is_chat is not None:
statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.order_by(col(BoardGroupMemory.created_at).desc())
return await paginate(session, statement)
return await paginate(session, statement.statement)
@group_router.get("/stream")
@@ -297,7 +297,7 @@ async def create_board_group_memory(
) -> BoardGroupMemory:
group = await _require_group_access(session, group_id=group_id, ctx=ctx, write=True)
user = await session.get(User, ctx.member.user_id)
user = await User.objects.by_id(ctx.member.user_id).first(session)
actor = ActorContext(actor_type="user", user=user)
tags = set(payload.tags or [])
is_chat = "chat" in tags
@@ -332,19 +332,18 @@ async def list_board_group_memory_for_board(
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
group_id = board.board_group_id
if group_id is None:
statement = select(BoardGroupMemory).where(col(BoardGroupMemory.id).is_(None))
return await paginate(session, statement)
return await paginate(session, BoardGroupMemory.objects.by_ids([]).statement)
statement = (
select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id)
queryset = (
BoardGroupMemory.objects.filter_by(board_group_id=group_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
.filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
)
if is_chat is not None:
statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.order_by(col(BoardGroupMemory.created_at).desc())
return await paginate(session, statement)
queryset = queryset.filter(col(BoardGroupMemory.is_chat) == is_chat)
queryset = queryset.order_by(col(BoardGroupMemory.created_at).desc())
return await paginate(session, queryset.statement)
@board_router.get("/stream")
@@ -396,7 +395,7 @@ async def create_board_group_memory_for_board(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board is not in a board group",
)
group = await session.get(BoardGroup, group_id)
group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

View File

@@ -56,7 +56,7 @@ async def _require_group_access(
member: OrganizationMember,
write: bool,
) -> BoardGroup:
group = await session.get(BoardGroup, group_id)
group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if group.organization_id != member.organization_id:
@@ -67,9 +67,9 @@ async def _require_group_access(
if not write and member_all_boards_read(member):
return group
board_ids = list(
await session.exec(select(Board.id).where(col(Board.board_group_id) == group_id))
)
board_ids = [
board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
]
if not board_ids:
if is_org_admin(member):
return group
@@ -153,7 +153,7 @@ async def apply_board_group_heartbeat(
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> BoardGroupHeartbeatApplyResult:
group = await session.get(BoardGroup, group_id)
group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -181,11 +181,11 @@ async def apply_board_group_heartbeat(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
board = await session.get(Board, agent.board_id)
board = await Board.objects.by_id(agent.board_id).first(session)
if board is None or board.board_group_id != group_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
boards = list(await session.exec(select(Board).where(col(Board.board_group_id) == group_id)))
boards = await Board.objects.filter_by(board_group_id=group_id).all(session)
board_by_id = {board.id: board for board in boards}
board_ids = list(board_by_id.keys())
if not board_ids:
@@ -196,7 +196,7 @@ async def apply_board_group_heartbeat(
failed_agent_ids=[],
)
agents = list(await session.exec(select(Agent).where(col(Agent.board_id).in_(board_ids))))
agents = await Agent.objects.by_field_in("board_id", board_ids).all(session)
if not payload.include_board_leads:
agents = [agent for agent in agents if not agent.is_board_lead]
@@ -232,7 +232,7 @@ async def apply_board_group_heartbeat(
failed_agent_ids: list[UUID] = []
gateway_ids = list(agents_by_gateway_id.keys())
gateways = list(await session.exec(select(Gateway).where(col(Gateway.id).in_(gateway_ids))))
gateways = await Gateway.objects.by_ids(gateway_ids).all(session)
gateway_by_id = {gateway.id: gateway for gateway in gateways}
for gateway_id, gateway_agents in agents_by_gateway_id.items():
gateway = gateway_by_id.get(gateway_id)

View File

@@ -8,7 +8,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func
from sqlmodel import col, select
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
@@ -58,7 +58,7 @@ def _serialize_memory(memory: BoardMemory) -> dict[str, object]:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if board.gateway_id is None:
return None
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url:
return None
return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -83,17 +83,17 @@ async def _fetch_memory_events(
is_chat: bool | None = None,
) -> list[BoardMemory]:
statement = (
select(BoardMemory).where(col(BoardMemory.board_id) == board_id)
BoardMemory.objects.filter_by(board_id=board_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardMemory.content))) > 0)
.filter(func.length(func.trim(col(BoardMemory.content))) > 0)
)
if is_chat is not None:
statement = statement.where(col(BoardMemory.is_chat) == is_chat)
statement = statement.where(col(BoardMemory.created_at) >= since).order_by(
statement = statement.filter(col(BoardMemory.is_chat) == is_chat)
statement = statement.filter(col(BoardMemory.created_at) >= since).order_by(
col(BoardMemory.created_at)
)
return list(await session.exec(statement))
return await statement.all(session)
async def _notify_chat_targets(
@@ -114,8 +114,7 @@ async def _notify_chat_targets(
# Special-case control commands to reach all board agents.
# These are intended to be parsed verbatim by agent runtimes.
if command in {"/pause", "/resume"}:
statement = select(Agent).where(col(Agent.board_id) == board.id)
pause_targets: list[Agent] = list(await session.exec(statement))
pause_targets: list[Agent] = await Agent.objects.filter_by(board_id=board.id).all(session)
for agent in pause_targets:
if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id:
continue
@@ -134,9 +133,8 @@ async def _notify_chat_targets(
return
mentions = extract_mentions(memory.content)
statement = select(Agent).where(col(Agent.board_id) == board.id)
targets: dict[str, Agent] = {}
for agent in await session.exec(statement):
for agent in await Agent.objects.filter_by(board_id=board.id).all(session):
if agent.is_board_lead:
targets[str(agent.id)] = agent
continue
@@ -188,15 +186,15 @@ async def list_board_memory(
actor: ActorContext = Depends(require_admin_or_agent),
) -> DefaultLimitOffsetPage[BoardMemoryRead]:
statement = (
select(BoardMemory).where(col(BoardMemory.board_id) == board.id)
BoardMemory.objects.filter_by(board_id=board.id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardMemory.content))) > 0)
.filter(func.length(func.trim(col(BoardMemory.content))) > 0)
)
if is_chat is not None:
statement = statement.where(col(BoardMemory.is_chat) == is_chat)
statement = statement.filter(col(BoardMemory.is_chat) == is_chat)
statement = statement.order_by(col(BoardMemory.created_at).desc())
return await paginate(session, statement)
return await paginate(session, statement.statement)
@router.get("/stream")

View File

@@ -6,7 +6,7 @@ from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import ValidationError
from sqlmodel import col, select
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import (
@@ -50,7 +50,7 @@ async def _gateway_config(
) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url or not gateway.main_session_key:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
return gateway, GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -80,12 +80,10 @@ async def _ensure_lead_agent(
identity_profile: dict[str, str] | None = None,
) -> Agent:
existing = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
)
).first()
await Agent.objects.filter_by(board_id=board.id)
.filter(col(Agent.is_board_lead).is_(True))
.first(session)
)
if existing:
desired_name = agent_name or _lead_agent_name(board)
if existing.name != desired_name:
@@ -147,12 +145,10 @@ async def get_onboarding(
session: AsyncSession = Depends(get_session),
) -> BoardOnboardingSession:
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc())
.first(session)
)
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return onboarding
@@ -165,12 +161,10 @@ async def start_onboarding(
session: AsyncSession = Depends(get_session),
) -> BoardOnboardingSession:
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.where(BoardOnboardingSession.status == "active")
)
).first()
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.filter(col(BoardOnboardingSession.status) == "active")
.first(session)
)
if onboarding:
return onboarding
@@ -248,12 +242,10 @@ async def answer_onboarding(
session: AsyncSession = Depends(get_session),
) -> BoardOnboardingSession:
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc())
.first(session)
)
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -295,18 +287,16 @@ async def agent_onboarding_update(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if board.gateway_id:
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway and gateway.main_session_key and agent.openclaw_session_id:
if agent.openclaw_session_id != gateway.main_session_key:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc())
.first(session)
)
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if onboarding.status == "confirmed":
@@ -351,12 +341,10 @@ async def confirm_onboarding(
auth: AuthContext = Depends(require_admin_auth),
) -> Board:
onboarding = (
await session.exec(
select(BoardOnboardingSession)
.where(BoardOnboardingSession.board_id == board.id)
.order_by(col(BoardOnboardingSession.created_at).desc())
)
).first()
await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc())
.first(session)
)
if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

View File

@@ -163,7 +163,7 @@ async def _board_gateway(
) -> tuple[Gateway | None, GatewayClientConfig | None]:
if not board.gateway_id:
return None, None
config = await session.get(Gateway, board.gateway_id)
config = await Gateway.objects.by_id(board.gateway_id).first(session)
if config is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -292,7 +292,7 @@ async def delete_board(
session: AsyncSession = Depends(get_session),
board: Board = Depends(get_board_for_user_write),
) -> OkResponse:
agents = list(await session.exec(select(Agent).where(Agent.board_id == board.id)))
agents = await Agent.objects.filter_by(board_id=board.id).all(session)
task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id)))
config, client_config = await _board_gateway(session, board)

View File

@@ -59,7 +59,7 @@ async def require_org_member(
member = await ensure_member_for_user(session, auth.user)
if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
organization = await session.get(Organization, member.organization_id)
organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return OrganizationContext(organization=organization, member=member)
@@ -77,7 +77,7 @@ async def get_board_or_404(
board_id: str,
session: AsyncSession = Depends(get_session),
) -> Board:
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return board
@@ -88,7 +88,7 @@ async def get_board_for_actor_read(
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> Board:
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "agent":
@@ -106,7 +106,7 @@ async def get_board_for_actor_write(
session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent),
) -> Board:
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "agent":
@@ -124,7 +124,7 @@ async def get_board_for_user_read(
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> Board:
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if auth.user is None:
@@ -138,7 +138,7 @@ async def get_board_for_user_write(
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
) -> Board:
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if auth.user is None:
@@ -152,7 +152,7 @@ async def get_task_or_404(
board: Board = Depends(get_board_for_actor_read),
session: AsyncSession = Depends(get_session),
) -> Task:
task = await session.get(Task, task_id)
task = await Task.objects.by_id(task_id).first(session)
if task is None or task.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return task

View File

@@ -56,7 +56,7 @@ async def _resolve_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id or gateway_url is required",
)
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if isinstance(user, object) and user is not None:
@@ -66,7 +66,7 @@ async def _resolve_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board gateway_id is required",
)
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -216,7 +216,7 @@ async def get_gateway_session(
sessions_list = list(sessions.get("sessions") or [])
else:
sessions_list = list(sessions or [])
if main_session and not any(session.get("key") == main_session for session in sessions_list):
if main_session and not any(item.get("key") == main_session for item in sessions_list):
try:
await ensure_session(main_session, config=config, label="Main Agent")
refreshed = await openclaw_call("sessions.list", config=config)

View File

@@ -2,12 +2,11 @@ from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlmodel import col, select
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin
from app.api.queryset import api_qs
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow
@@ -43,14 +42,14 @@ async def _require_gateway(
gateway_id: UUID,
organization_id: UUID,
) -> Gateway:
return await (
api_qs(Gateway)
.filter(
col(Gateway.id) == gateway_id,
col(Gateway.organization_id) == organization_id,
)
.first_or_404(session, detail="Gateway not found")
gateway = (
await Gateway.objects.by_id(gateway_id)
.filter(col(Gateway.organization_id) == organization_id)
.first(session)
)
if gateway is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
return gateway
async def _find_main_agent(
@@ -60,26 +59,22 @@ async def _find_main_agent(
previous_session_key: str | None = None,
) -> Agent | None:
if gateway.main_session_key:
agent = (
await session.exec(
select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key)
)
).first()
agent = await Agent.objects.filter_by(openclaw_session_id=gateway.main_session_key).first(
session
)
if agent:
return agent
if previous_session_key:
agent = (
await session.exec(
select(Agent).where(Agent.openclaw_session_id == previous_session_key)
)
).first()
agent = await Agent.objects.filter_by(openclaw_session_id=previous_session_key).first(
session
)
if agent:
return agent
names = {_main_agent_name(gateway)}
if previous_name:
names.add(f"{previous_name} Main")
for name in names:
agent = (await session.exec(select(Agent).where(Agent.name == name))).first()
agent = await Agent.objects.filter_by(name=name).first(session)
if agent:
return agent
return None
@@ -153,8 +148,7 @@ async def list_gateways(
ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[GatewayRead]:
statement = (
api_qs(Gateway)
.filter(col(Gateway.organization_id) == ctx.organization.id)
Gateway.objects.filter_by(organization_id=ctx.organization.id)
.order_by(col(Gateway.created_at).desc())
.statement
)

View File

@@ -10,7 +10,6 @@ from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin, require_org_member
from app.api.queryset import api_qs
from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow
from app.db import crud
@@ -81,14 +80,10 @@ async def _require_org_member(
organization_id: UUID,
member_id: UUID,
) -> OrganizationMember:
return await (
api_qs(OrganizationMember)
.filter(
col(OrganizationMember.id) == member_id,
col(OrganizationMember.organization_id) == organization_id,
)
.first_or_404(session)
)
member = await OrganizationMember.objects.by_id(member_id).first(session)
if member is None or member.organization_id != organization_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return member
async def _require_org_invite(
@@ -97,14 +92,10 @@ async def _require_org_invite(
organization_id: UUID,
invite_id: UUID,
) -> OrganizationInvite:
return await (
api_qs(OrganizationInvite)
.filter(
col(OrganizationInvite.id) == invite_id,
col(OrganizationInvite.organization_id) == organization_id,
)
.first_or_404(session)
)
invite = await OrganizationInvite.objects.by_id(invite_id).first(session)
if invite is None or invite.organization_id != organization_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return invite
@router.post("", response_model=OrganizationRead)
@@ -157,7 +148,7 @@ async def list_my_organizations(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await get_active_membership(session, auth.user)
db_user = await session.get(User, auth.user.id)
db_user = await User.objects.by_id(auth.user.id).first(session)
active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id
statement = (
@@ -189,7 +180,7 @@ async def set_active_org(
member = await set_active_organization(
session, user=auth.user, organization_id=payload.organization_id
)
organization = await session.get(Organization, member.organization_id)
organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return OrganizationRead.model_validate(organization, from_attributes=True)
@@ -293,14 +284,10 @@ async def get_my_membership(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_member),
) -> OrganizationMemberRead:
user = await session.get(User, ctx.member.user_id)
access_rows = list(
await session.exec(
select(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id) == ctx.member.id
)
)
)
user = await User.objects.by_id(ctx.member.user_id).first(session)
access_rows = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=ctx.member.id
).all(session)
model = _member_to_read(ctx.member, user)
model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
@@ -342,14 +329,10 @@ async def get_org_member(
)
if not is_org_admin(ctx.member) and member.user_id != ctx.member.user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
user = await session.get(User, member.user_id)
access_rows = list(
await session.exec(
select(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id) == member.id
)
)
)
user = await User.objects.by_id(member.user_id).first(session)
access_rows = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=member.id
).all(session)
model = _member_to_read(member, user)
model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
@@ -374,7 +357,7 @@ async def update_org_member(
updates["role"] = normalize_role(updates["role"])
updates["updated_at"] = utcnow()
member = await crud.patch(session, member, updates)
user = await session.get(User, member.user_id)
user = await User.objects.by_id(member.user_id).first(session)
return _member_to_read(member, user)
@@ -393,20 +376,19 @@ async def update_member_access(
board_ids = {entry.board_id for entry in payload.board_access}
if board_ids:
valid_board_ids = set(
await session.exec(
select(Board.id)
.where(col(Board.id).in_(board_ids))
.where(col(Board.organization_id) == ctx.organization.id)
)
)
valid_board_ids = {
board.id
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(Board.id).in_(board_ids))
.all(session)
}
if valid_board_ids != board_ids:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
await apply_member_access_update(session, member=member, update=payload)
await session.commit()
await session.refresh(member)
user = await session.get(User, member.user_id)
user = await User.objects.by_id(member.user_id).first(session)
return _member_to_read(member, user)
@@ -416,9 +398,11 @@ async def remove_org_member(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse:
member = await session.get(OrganizationMember, member_id)
if member is None or member.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
member_id=member_id,
)
if member.user_id == ctx.member.user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -430,15 +414,12 @@ async def remove_org_member(
detail="Only owners can remove owners",
)
if member.role == "owner":
owner_ids = list(
await session.exec(
select(OrganizationMember.id).where(
col(OrganizationMember.organization_id) == ctx.organization.id,
col(OrganizationMember.role) == "owner",
)
)
owners = (
await OrganizationMember.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(OrganizationMember.role) == "owner")
.all(session)
)
if len(owner_ids) <= 1:
if len(owners) <= 1:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail="Organization must have at least one owner",
@@ -451,17 +432,22 @@ async def remove_org_member(
),
)
user = await session.get(User, member.user_id)
user = await User.objects.by_id(member.user_id).first(session)
if user is not None and user.active_organization_id == ctx.organization.id:
fallback_org_id = (
await session.exec(
select(OrganizationMember.organization_id)
.where(col(OrganizationMember.user_id) == user.id)
.where(col(OrganizationMember.organization_id) != ctx.organization.id)
.order_by(col(OrganizationMember.created_at).asc())
fallback_membership = (
await OrganizationMember.objects.filter(
col(OrganizationMember.user_id) == user.id,
col(OrganizationMember.organization_id) != ctx.organization.id,
)
.order_by(col(OrganizationMember.created_at).asc())
.first(session)
)
if isinstance(fallback_membership, UUID):
user.active_organization_id = fallback_membership
else:
user.active_organization_id = (
fallback_membership.organization_id if fallback_membership is not None else None
)
).first()
user.active_organization_id = fallback_org_id
session.add(user)
await crud.delete(session, member)
@@ -474,8 +460,7 @@ async def list_org_invites(
ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
statement = (
api_qs(OrganizationInvite)
.filter(col(OrganizationInvite.organization_id) == ctx.organization.id)
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(OrganizationInvite.accepted_at).is_(None))
.order_by(col(OrganizationInvite.created_at).desc())
.statement
@@ -522,13 +507,12 @@ async def create_org_invite(
board_ids = {entry.board_id for entry in payload.board_access}
if board_ids:
valid_board_ids = set(
await session.exec(
select(Board.id)
.where(col(Board.id).in_(board_ids))
.where(col(Board.organization_id) == ctx.organization.id)
)
)
valid_board_ids = {
board.id
for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(Board.id).in_(board_ids))
.all(session)
}
if valid_board_ids != board_ids:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
await apply_invite_board_access(session, invite=invite, entries=payload.board_access)
@@ -566,13 +550,10 @@ async def accept_org_invite(
) -> OrganizationMemberRead:
if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
invite = (
await session.exec(
select(OrganizationInvite)
.where(col(OrganizationInvite.token) == payload.token)
.where(col(OrganizationInvite.accepted_at).is_(None))
)
).first()
invite = await OrganizationInvite.objects.filter(
col(OrganizationInvite.token) == payload.token,
col(OrganizationInvite.accepted_at).is_(None),
).first(session)
if invite is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if invite.invited_email and auth.user.email:
@@ -597,5 +578,5 @@ async def accept_org_invite(
await session.commit()
member = existing
user = await session.get(User, member.user_id)
user = await User.objects.by_id(member.user_id).first(session)
return _member_to_read(member, user)

View File

@@ -243,7 +243,7 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if not board.gateway_id:
return None
gateway = await session.get(Gateway, board.gateway_id)
gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url:
return None
return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -331,12 +331,10 @@ async def _notify_lead_on_task_create(
task: Task,
) -> None:
lead = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
)
).first()
await Agent.objects.filter_by(board_id=board.id)
.filter(col(Agent.is_board_lead).is_(True))
.first(session)
)
if lead is None or not lead.openclaw_session_id:
return
config = await _gateway_config(session, board)
@@ -390,12 +388,10 @@ async def _notify_lead_on_task_unassigned(
task: Task,
) -> None:
lead = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
)
).first()
await Agent.objects.filter_by(board_id=board.id)
.filter(col(Agent.is_board_lead).is_(True))
.first(session)
)
if lead is None or not lead.openclaw_session_id:
return
config = await _gateway_config(session, board)
@@ -635,7 +631,7 @@ async def create_task(
await session.commit()
await _notify_lead_on_task_create(session=session, board=board, task=task)
if task.assigned_agent_id:
assigned_agent = await session.get(Agent, task.assigned_agent_id)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent:
await _notify_agent_on_task_assign(
session=session,
@@ -670,7 +666,7 @@ async def update_task(
)
board_id = task.board_id
if actor.actor_type == "user" and actor.user is not None:
board = await session.get(Board, board_id)
board = await Board.objects.by_id(board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await require_board_access(session, user=actor.user, board=board, write=True)
@@ -740,7 +736,7 @@ async def update_task(
if "assigned_agent_id" in updates:
assigned_id = updates["assigned_agent_id"]
if assigned_id:
agent = await session.get(Agent, assigned_id)
agent = await Agent.objects.by_id(assigned_id).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.is_board_lead:
@@ -796,9 +792,13 @@ async def update_task(
await session.refresh(task)
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
assigned_agent = await session.get(Agent, task.assigned_agent_id)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent:
board = await session.get(Board, task.board_id) if task.board_id else None
board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
if board:
await _notify_agent_on_task_assign(
session=session,
@@ -879,7 +879,7 @@ async def update_task(
task.in_progress_at = utcnow()
if "assigned_agent_id" in updates and updates["assigned_agent_id"]:
agent = await session.get(Agent, updates["assigned_agent_id"])
agent = await Agent.objects.by_id(updates["assigned_agent_id"]).first(session)
if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.board_id and task.board_id and agent.board_id != task.board_id:
@@ -941,7 +941,9 @@ async def update_task(
if task.status == "inbox" and task.assigned_agent_id is None:
if previous_status != "inbox" or previous_assigned is not None:
board = await session.get(Board, task.board_id) if task.board_id else None
board = (
await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
)
if board:
await _notify_lead_on_task_unassigned(
session=session,
@@ -953,9 +955,13 @@ async def update_task(
# Don't notify the actor about their own assignment.
pass
else:
assigned_agent = await session.get(Agent, task.assigned_agent_id)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent:
board = await session.get(Board, task.board_id) if task.board_id else None
board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
if board:
await _notify_agent_on_task_assign(
session=session,
@@ -985,7 +991,7 @@ async def delete_task(
) -> OkResponse:
if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
board = await session.get(Board, task.board_id)
board = await Board.objects.by_id(task.board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if auth.user is None:
@@ -1032,7 +1038,7 @@ async def create_task_comment(
if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
if actor.actor_type == "user" and actor.user is not None:
board = await session.get(Board, task.board_id)
board = await Board.objects.by_id(task.board_id).first(session)
if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await require_board_access(session, user=actor.user, board=board, write=True)
@@ -1059,18 +1065,17 @@ async def create_task_comment(
mention_names = extract_mentions(payload.message)
targets: dict[UUID, Agent] = {}
if mention_names and task.board_id:
statement = select(Agent).where(col(Agent.board_id) == task.board_id)
for agent in await session.exec(statement):
for agent in await Agent.objects.filter_by(board_id=task.board_id).all(session):
if matches_agent_mention(agent, mention_names):
targets[agent.id] = agent
if not mention_names and task.assigned_agent_id:
assigned_agent = await session.get(Agent, task.assigned_agent_id)
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent:
targets[assigned_agent.id] = assigned_agent
if actor.actor_type == "agent" and actor.agent:
targets.pop(actor.agent.id, None)
if targets:
board = await session.get(Board, task.board_id) if task.board_id else None
board = await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
config = await _gateway_config(session, board) if board else None
if board and config:
snippet = payload.message.strip()

View File

@@ -27,7 +27,8 @@ def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectO
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
return await session.get(model, obj_id)
stmt = _lookup_statement(model, {"id": obj_id}).limit(1)
return (await session.exec(stmt)).first()
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from sqlalchemy import false
from sqlmodel import SQLModel, col
from app.db.queryset import QuerySet, qs
ModelT = TypeVar("ModelT", bound=SQLModel)
@dataclass(frozen=True)
class ModelManager(Generic[ModelT]):
model: type[ModelT]
id_field: str = "id"
def all(self) -> QuerySet[ModelT]:
return qs(self.model)
def none(self) -> QuerySet[ModelT]:
return qs(self.model).filter(false())
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
return self.all().filter(*criteria)
def where(self, *criteria: Any) -> QuerySet[ModelT]:
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
queryset = self.all()
for field_name, value in kwargs.items():
queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
return queryset
def by_id(self, obj_id: Any) -> QuerySet[ModelT]:
return self.by_field(self.id_field, obj_id)
def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]:
return self.by_field_in(self.id_field, obj_ids)
def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]:
return self.filter(col(getattr(self.model, field_name)) == value)
def by_field_in(
self,
field_name: str,
values: list[Any] | tuple[Any, ...] | set[Any],
) -> QuerySet[ModelT]:
seq = tuple(values)
if not seq:
return self.none()
return self.filter(col(getattr(self.model, field_name)).in_(seq))
class ManagerDescriptor(Generic[ModelT]):
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
return ModelManager(owner)

View File

@@ -17,6 +17,13 @@ class QuerySet(Generic[ModelT]):
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.where(*criteria))
def where(self, *criteria: Any) -> QuerySet[ModelT]:
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
statement = self.statement.filter_by(**kwargs)
return replace(self, statement=statement)
def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.order_by(*ordering))

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from datetime import datetime
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class ActivityEvent(SQLModel, table=True):
class ActivityEvent(QueryModel, table=True):
__tablename__ = "activity_events"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -5,12 +5,13 @@ from typing import Any
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column, Text
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class Agent(SQLModel, table=True):
class Agent(QueryModel, table=True):
__tablename__ = "agents"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class Approval(SQLModel, table=True):
class Approval(QueryModel, table=True):
__tablename__ = "approvals"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from typing import ClassVar, Self
from sqlmodel import SQLModel
from app.db.query_manager import ManagerDescriptor
class QueryModel(SQLModel, table=False):
objects: ClassVar[ManagerDescriptor[Self]] = ManagerDescriptor()

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class BoardGroupMemory(SQLModel, table=True):
class BoardGroupMemory(QueryModel, table=True):
__tablename__ = "board_group_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class BoardMemory(SQLModel, table=True):
class BoardMemory(QueryModel, table=True):
__tablename__ = "board_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class BoardOnboardingSession(SQLModel, table=True):
class BoardOnboardingSession(QueryModel, table=True):
__tablename__ = "board_onboarding_sessions"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from datetime import datetime
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class Gateway(SQLModel, table=True):
class Gateway(QueryModel, table=True):
__tablename__ = "gateways"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationBoardAccess(SQLModel, table=True):
class OrganizationBoardAccess(QueryModel, table=True):
__tablename__ = "organization_board_access"
__table_args__ = (
UniqueConstraint(

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationInviteBoardAccess(SQLModel, table=True):
class OrganizationInviteBoardAccess(QueryModel, table=True):
__tablename__ = "organization_invite_board_access"
__table_args__ = (
UniqueConstraint(

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationInvite(SQLModel, table=True):
class OrganizationInvite(QueryModel, table=True):
__tablename__ = "organization_invites"
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationMember(SQLModel, table=True):
class OrganizationMember(QueryModel, table=True):
__tablename__ = "organization_members"
__table_args__ = (
UniqueConstraint(

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class Organization(SQLModel, table=True):
class Organization(QueryModel, table=True):
__tablename__ = "organizations"
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from datetime import datetime
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
class TaskFingerprint(SQLModel, table=True):
class TaskFingerprint(QueryModel, table=True):
__tablename__ = "task_fingerprints"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlmodel import SQLModel
from app.models.base import QueryModel
class TenantScoped(SQLModel, table=False):
class TenantScoped(QueryModel, table=False):
pass

View File

@@ -2,10 +2,12 @@ from __future__ import annotations
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel
from sqlmodel import Field
from app.models.base import QueryModel
class User(SQLModel, table=True):
class User(QueryModel, table=True):
__tablename__ = "users"
id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1 +0,0 @@
from __future__ import annotations

View File

@@ -1,50 +0,0 @@
from __future__ import annotations
from uuid import UUID
from sqlmodel import col
from app.db.queryset import QuerySet, qs
from app.models.organization_board_access import OrganizationBoardAccess
from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization
def organization_by_name(name: str) -> QuerySet[Organization]:
return qs(Organization).filter(col(Organization.name) == name)
def member_by_user_and_org(*, user_id: UUID, organization_id: UUID) -> QuerySet[OrganizationMember]:
return qs(OrganizationMember).filter(
col(OrganizationMember.organization_id) == organization_id,
col(OrganizationMember.user_id) == user_id,
)
def first_membership_for_user(user_id: UUID) -> QuerySet[OrganizationMember]:
return (
qs(OrganizationMember)
.filter(col(OrganizationMember.user_id) == user_id)
.order_by(col(OrganizationMember.created_at).asc())
)
def pending_invite_by_email(email: str) -> QuerySet[OrganizationInvite]:
return (
qs(OrganizationInvite)
.filter(col(OrganizationInvite.accepted_at).is_(None))
.filter(col(OrganizationInvite.invited_email) == email)
.order_by(col(OrganizationInvite.created_at).asc())
)
def board_access_for_member_and_board(
*,
organization_member_id: UUID,
board_id: UUID,
) -> QuerySet[OrganizationBoardAccess]:
return qs(OrganizationBoardAccess).filter(
col(OrganizationBoardAccess.organization_member_id) == organization_member_id,
col(OrganizationBoardAccess.board_id) == board_id,
)

View File

@@ -42,7 +42,7 @@ async def build_group_snapshot(
include_done: bool = False,
per_board_task_limit: int = 5,
) -> BoardGroupSnapshot:
statement = select(Board).where(col(Board.board_group_id) == group.id)
statement = Board.objects.filter_by(board_group_id=group.id).statement
if exclude_board_id is not None:
statement = statement.where(col(Board.id) != exclude_board_id)
boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc())))
@@ -146,7 +146,7 @@ async def build_board_group_snapshot(
) -> BoardGroupSnapshot:
if not board.board_group_id:
return BoardGroupSnapshot(group=None, boards=[])
group = await session.get(BoardGroup, board.board_group_id)
group = await BoardGroup.objects.by_id(board.board_group_id).first(session)
if group is None:
return BoardGroupSnapshot(group=None, boards=[])
return await build_group_snapshot(

View File

@@ -97,9 +97,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
board_read = BoardRead.model_validate(board, from_attributes=True)
tasks = list(
await session.exec(
select(Task).where(col(Task.board_id) == board.id).order_by(col(Task.created_at).desc())
)
await Task.objects.filter_by(board_id=board.id)
.order_by(col(Task.created_at).desc())
.all(session)
)
task_ids = [task.id for task in tasks]
@@ -114,12 +114,10 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
)
main_session_keys = await _gateway_main_session_keys(session)
agents = list(
await session.exec(
select(Agent)
.where(col(Agent.board_id) == board.id)
.order_by(col(Agent.created_at).desc())
)
agents = (
await Agent.objects.filter_by(board_id=board.id)
.order_by(col(Agent.created_at).desc())
.all(session)
)
agent_reads = [_agent_to_read(agent, main_session_keys) for agent in agents]
agent_name_by_id = {agent.id: agent.name for agent in agents}
@@ -134,13 +132,11 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
).one()
)
approvals = list(
await session.exec(
select(Approval)
.where(col(Approval.board_id) == board.id)
.order_by(col(Approval.created_at).desc())
.limit(200)
)
approvals = (
await Approval.objects.filter_by(board_id=board.id)
.order_by(col(Approval.created_at).desc())
.limit(200)
.all(session)
)
approval_reads = [_approval_to_read(approval) for approval in approvals]
@@ -173,17 +169,15 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
for task in tasks
]
chat_messages = list(
await session.exec(
select(BoardMemory)
.where(col(BoardMemory.board_id) == board.id)
.where(col(BoardMemory.is_chat).is_(True))
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardMemory.content))) > 0)
.order_by(col(BoardMemory.created_at).desc())
.limit(200)
)
chat_messages = (
await BoardMemory.objects.filter_by(board_id=board.id)
.filter(col(BoardMemory.is_chat).is_(True))
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema.
.filter(func.length(func.trim(col(BoardMemory.content))) > 0)
.order_by(col(BoardMemory.created_at).desc())
.limit(200)
.all(session)
)
chat_messages.sort(key=lambda item: item.created_at)
chat_reads = [_memory_to_read(memory) for memory in chat_messages]

View File

@@ -19,7 +19,6 @@ from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization
from app.models.users import User
from app.queries import organizations as org_queries
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
DEFAULT_ORG_NAME = "Personal"
@@ -38,7 +37,7 @@ def is_org_admin(member: OrganizationMember) -> bool:
async def get_default_org(session: AsyncSession) -> Organization | None:
return await org_queries.organization_by_name(DEFAULT_ORG_NAME).first(session)
return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session)
async def ensure_default_org(session: AsyncSession) -> Organization:
@@ -58,14 +57,18 @@ async def get_member(
user_id: UUID,
organization_id: UUID,
) -> OrganizationMember | None:
return await org_queries.member_by_user_and_org(
return await OrganizationMember.objects.filter_by(
user_id=user_id,
organization_id=organization_id,
).first(session)
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
return await org_queries.first_membership_for_user(user_id).first(session)
return (
await OrganizationMember.objects.filter_by(user_id=user_id)
.order_by(col(OrganizationMember.created_at).asc())
.first(session)
)
async def set_active_organization(
@@ -88,7 +91,7 @@ async def get_active_membership(
session: AsyncSession,
user: User,
) -> OrganizationMember | None:
db_user = await session.get(User, user.id)
db_user = await User.objects.by_id(user.id).first(session)
if db_user is None:
db_user = user
if db_user.active_organization_id:
@@ -119,7 +122,14 @@ async def _find_pending_invite(
session: AsyncSession,
email: str,
) -> OrganizationInvite | None:
return await org_queries.pending_invite_by_email(email).first(session)
return (
await OrganizationInvite.objects.filter(
col(OrganizationInvite.accepted_at).is_(None),
col(OrganizationInvite.invited_email) == email,
)
.order_by(col(OrganizationInvite.created_at).asc())
.first(session)
)
async def accept_invite(
@@ -230,7 +240,7 @@ async def has_board_access(
else:
if member_all_boards_read(member):
return True
access = await org_queries.board_access_for_member_and_board(
access = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=member.id,
board_id=board.id,
).first(session)

View File

@@ -328,7 +328,7 @@ async def sync_gateway_templates(
result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
return result
boards = list(await session.exec(select(Board).where(col(Board.gateway_id) == gateway.id)))
boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session)
boards_by_id = {board.id: board for board in boards}
if board_id is not None:
board = boards_by_id.get(board_id)
@@ -345,12 +345,10 @@ async def sync_gateway_templates(
paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys()))
if boards_by_id:
agents = list(
await session.exec(
select(Agent)
.where(col(Agent.board_id).in_(list(boards_by_id.keys())))
.order_by(col(Agent.created_at).asc())
)
agents = await (
Agent.objects.by_field_in("board_id", list(boards_by_id.keys()))
.order_by(col(Agent.created_at).asc())
.all(session)
)
else:
agents = []
@@ -471,10 +469,10 @@ async def sync_gateway_templates(
if include_main:
main_agent = (
await session.exec(
select(Agent).where(col(Agent.openclaw_session_id) == gateway.main_session_key)
)
).first()
await Agent.objects.all()
.filter(col(Agent.openclaw_session_id) == gateway.main_session_key)
.first(session)
)
if main_agent is None:
result.errors.append(
GatewayTemplatesSyncError(