refactor: update migration paths and improve database operation handling

This commit is contained in:
Abhimanyu Saharan
2026-02-09 00:51:26 +05:30
parent 8c4bcca603
commit f6bcd1ca5f
43 changed files with 1175 additions and 1445 deletions

View File

@@ -21,6 +21,7 @@ from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow
from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session
from app.db.sqlmodel_exec import exec_dml
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.activity_events import ActivityEvent
@@ -973,7 +974,8 @@ async def delete_agent(
agent_id=None,
)
now = datetime.now()
await session.execute(
await exec_dml(
session,
update(Task)
.where(col(Task.assigned_agent_id) == agent.id)
.where(col(Task.status) == "in_progress")
@@ -982,19 +984,21 @@ async def delete_agent(
status="inbox",
in_progress_at=None,
updated_at=now,
)
),
)
await session.execute(
await exec_dml(
session,
update(Task)
.where(col(Task.assigned_agent_id) == agent.id)
.where(col(Task.status) != "in_progress")
.values(
assigned_agent_id=None,
updated_at=now,
)
),
)
await session.execute(
update(ActivityEvent).where(col(ActivityEvent.agent_id) == agent.id).values(agent_id=None)
await exec_dml(
session,
update(ActivityEvent).where(col(ActivityEvent.agent_id) == agent.id).values(agent_id=None),
)
await session.delete(agent)
await session.commit()

View File

@@ -14,6 +14,7 @@ from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate
from app.db.session import get_session
from app.db.sqlmodel_exec import exec_dml
from app.models.agents import Agent
from app.models.board_group_memory import BoardGroupMemory
from app.models.board_groups import BoardGroup
@@ -262,10 +263,8 @@ async def update_board_group(
updates = payload.model_dump(exclude_unset=True)
if "slug" in updates and updates["slug"] is not None and not updates["slug"].strip():
updates["slug"] = _slugify(updates.get("name") or group.name)
for key, value in updates.items():
setattr(group, key, value)
group.updated_at = utcnow()
return await crud.save(session, group)
updates["updated_at"] = utcnow()
return await crud.patch(session, group, updates)
@router.delete("/{group_id}", response_model=OkResponse)
@@ -277,12 +276,14 @@ async def delete_board_group(
await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)
# Boards reference groups, so clear the FK first to keep deletes simple.
await session.execute(
update(Board).where(col(Board.board_group_id) == group_id).values(board_group_id=None)
await exec_dml(
session,
update(Board).where(col(Board.board_group_id) == group_id).values(board_group_id=None),
)
await session.execute(
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id)
await exec_dml(
session,
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id),
)
await session.execute(delete(BoardGroup).where(col(BoardGroup.id) == group_id))
await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.id) == group_id))
await session.commit()
return OkResponse()

View File

@@ -19,6 +19,7 @@ from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate
from app.db.session import get_session
from app.db.sqlmodel_exec import exec_dml
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
@@ -140,8 +141,7 @@ async def _apply_board_update(
updates["board_group_id"],
organization_id=board.organization_id,
)
for key, value in updates.items():
setattr(board, key, value)
crud.apply_updates(board, updates)
if updates.get("board_type") == "goal":
# Validate only when explicitly switching to goal boards.
if not board.objective or not board.success_metrics:
@@ -307,36 +307,43 @@ async def delete_board(
) from exc
if task_ids:
await session.execute(delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)))
await session.execute(delete(TaskDependency).where(col(TaskDependency.board_id) == board.id))
await session.execute(delete(TaskFingerprint).where(col(TaskFingerprint.board_id) == board.id))
await exec_dml(
session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids))
)
await exec_dml(session, delete(TaskDependency).where(col(TaskDependency.board_id) == board.id))
await exec_dml(
session, delete(TaskFingerprint).where(col(TaskFingerprint.board_id) == board.id)
)
# Approvals can reference tasks and agents, so delete before both.
await session.execute(delete(Approval).where(col(Approval.board_id) == board.id))
await exec_dml(session, delete(Approval).where(col(Approval.board_id) == board.id))
await session.execute(delete(BoardMemory).where(col(BoardMemory.board_id) == board.id))
await session.execute(
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id) == board.id)
await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id) == board.id))
await exec_dml(
session,
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id) == board.id),
)
await session.execute(
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id) == board.id)
await exec_dml(
session,
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id) == board.id),
)
await session.execute(
await exec_dml(
session,
delete(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.board_id) == board.id
)
),
)
# Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so
# delete tasks before agents.
await session.execute(delete(Task).where(col(Task.board_id) == board.id))
await exec_dml(session, delete(Task).where(col(Task.board_id) == board.id))
if agents:
agent_ids = [agent.id for agent in agents]
await session.execute(
delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids))
await exec_dml(
session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids))
)
await session.execute(delete(Agent).where(col(Agent.id).in_(agent_ids)))
await exec_dml(session, delete(Agent).where(col(Agent.id).in_(agent_ids)))
await session.delete(board)
await session.commit()
return OkResponse()

View File

@@ -2,14 +2,16 @@ from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, Query
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin
from app.api.queryset import api_qs
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate
from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
@@ -35,6 +37,22 @@ def _main_agent_name(gateway: Gateway) -> str:
return f"{gateway.name} Main"
async def _require_gateway(
session: AsyncSession,
*,
gateway_id: UUID,
organization_id: UUID,
) -> Gateway:
return await (
api_qs(Gateway)
.filter(
col(Gateway.id) == gateway_id,
col(Gateway.organization_id) == organization_id,
)
.first_or_404(session, detail="Gateway not found")
)
async def _find_main_agent(
session: AsyncSession,
gateway: Gateway,
@@ -135,9 +153,10 @@ async def list_gateways(
ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[GatewayRead]:
statement = (
select(Gateway)
.where(col(Gateway.organization_id) == ctx.organization.id)
api_qs(Gateway)
.filter(col(Gateway.organization_id) == ctx.organization.id)
.order_by(col(Gateway.created_at).desc())
.statement
)
return await paginate(session, statement)
@@ -151,10 +170,7 @@ async def create_gateway(
) -> Gateway:
data = payload.model_dump()
data["organization_id"] = ctx.organization.id
gateway = Gateway.model_validate(data)
session.add(gateway)
await session.commit()
await session.refresh(gateway)
gateway = await crud.create(session, Gateway, **data)
await _ensure_main_agent(session, gateway, auth, action="provision")
return gateway
@@ -165,10 +181,11 @@ async def get_gateway(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
return gateway
return await _require_gateway(
session,
gateway_id=gateway_id,
organization_id=ctx.organization.id,
)
@router.patch("/{gateway_id}", response_model=GatewayRead)
@@ -179,17 +196,15 @@ async def update_gateway(
auth: AuthContext = Depends(get_auth_context),
ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
gateway = await _require_gateway(
session,
gateway_id=gateway_id,
organization_id=ctx.organization.id,
)
previous_name = gateway.name
previous_session_key = gateway.main_session_key
updates = payload.model_dump(exclude_unset=True)
for key, value in updates.items():
setattr(gateway, key, value)
session.add(gateway)
await session.commit()
await session.refresh(gateway)
await crud.patch(session, gateway, updates)
await _ensure_main_agent(
session,
gateway,
@@ -213,9 +228,11 @@ async def sync_gateway_templates(
auth: AuthContext = Depends(get_auth_context),
ctx: OrganizationContext = Depends(require_org_admin),
) -> GatewayTemplatesSyncResult:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
gateway = await _require_gateway(
session,
gateway_id=gateway_id,
organization_id=ctx.organization.id,
)
return await sync_gateway_templates_service(
session,
gateway,
@@ -234,9 +251,10 @@ async def delete_gateway(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse:
gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
await session.delete(gateway)
await session.commit()
gateway = await _require_gateway(
session,
gateway_id=gateway_id,
organization_id=ctx.organization.id,
)
await crud.delete(session, gateway)
return OkResponse()

View File

@@ -10,10 +10,13 @@ from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin, require_org_member
from app.api.queryset import api_qs
from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate
from app.db.session import get_session
from app.db.sqlmodel_exec import exec_dml
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approvals import Approval
@@ -72,6 +75,38 @@ def _member_to_read(member: OrganizationMember, user: User | None) -> Organizati
return model
async def _require_org_member(
session: AsyncSession,
*,
organization_id: UUID,
member_id: UUID,
) -> OrganizationMember:
return await (
api_qs(OrganizationMember)
.filter(
col(OrganizationMember.id) == member_id,
col(OrganizationMember.organization_id) == organization_id,
)
.first_or_404(session)
)
async def _require_org_invite(
session: AsyncSession,
*,
organization_id: UUID,
invite_id: UUID,
) -> OrganizationInvite:
return await (
api_qs(OrganizationInvite)
.filter(
col(OrganizationInvite.id) == invite_id,
col(OrganizationInvite.organization_id) == organization_id,
)
.first_or_404(session)
)
@router.post("", response_model=OrganizationRead)
async def create_organization(
payload: OrganizationCreate,
@@ -188,55 +223,67 @@ async def delete_my_org(
)
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
await session.execute(delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)))
await session.execute(delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids)))
await session.execute(delete(TaskDependency).where(col(TaskDependency.board_id).in_(board_ids)))
await session.execute(
delete(TaskFingerprint).where(col(TaskFingerprint.board_id).in_(board_ids))
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)))
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids)))
await exec_dml(
session, delete(TaskDependency).where(col(TaskDependency.board_id).in_(board_ids))
)
await session.execute(delete(Approval).where(col(Approval.board_id).in_(board_ids)))
await session.execute(delete(BoardMemory).where(col(BoardMemory.board_id).in_(board_ids)))
await session.execute(
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id).in_(board_ids))
await exec_dml(
session,
delete(TaskFingerprint).where(col(TaskFingerprint.board_id).in_(board_ids)),
)
await session.execute(
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id).in_(board_ids))
await exec_dml(session, delete(Approval).where(col(Approval.board_id).in_(board_ids)))
await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id).in_(board_ids)))
await exec_dml(
session,
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id).in_(board_ids)),
)
await session.execute(
await exec_dml(
session,
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id).in_(board_ids)),
)
await exec_dml(
session,
delete(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.board_id).in_(board_ids)
)
),
)
await session.execute(
await exec_dml(
session,
delete(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id).in_(member_ids)
)
),
)
await session.execute(
await exec_dml(
session,
delete(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids)
)
),
)
await session.execute(delete(Task).where(col(Task.board_id).in_(board_ids)))
await session.execute(delete(Agent).where(col(Agent.board_id).in_(board_ids)))
await session.execute(delete(Board).where(col(Board.organization_id) == org_id))
await session.execute(
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id).in_(group_ids))
await exec_dml(session, delete(Task).where(col(Task.board_id).in_(board_ids)))
await exec_dml(session, delete(Agent).where(col(Agent.board_id).in_(board_ids)))
await exec_dml(session, delete(Board).where(col(Board.organization_id) == org_id))
await exec_dml(
session,
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id).in_(group_ids)),
)
await session.execute(delete(BoardGroup).where(col(BoardGroup.organization_id) == org_id))
await session.execute(delete(Gateway).where(col(Gateway.organization_id) == org_id))
await session.execute(
delete(OrganizationInvite).where(col(OrganizationInvite.organization_id) == org_id)
await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.organization_id) == org_id))
await exec_dml(session, delete(Gateway).where(col(Gateway.organization_id) == org_id))
await exec_dml(
session,
delete(OrganizationInvite).where(col(OrganizationInvite.organization_id) == org_id),
)
await session.execute(
delete(OrganizationMember).where(col(OrganizationMember.organization_id) == org_id)
await exec_dml(
session,
delete(OrganizationMember).where(col(OrganizationMember.organization_id) == org_id),
)
await session.execute(
await exec_dml(
session,
update(User)
.where(col(User.active_organization_id) == org_id)
.values(active_organization_id=None)
.values(active_organization_id=None),
)
await session.execute(delete(Organization).where(col(Organization.id) == org_id))
await exec_dml(session, delete(Organization).where(col(Organization.id) == org_id))
await session.commit()
return OkResponse()
@@ -288,9 +335,11 @@ async def get_org_member(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_member),
) -> OrganizationMemberRead:
member = await session.get(OrganizationMember, member_id)
if member is None or member.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
member_id=member_id,
)
if not is_org_admin(ctx.member) and member.user_id != ctx.member.user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
user = await session.get(User, member.user_id)
@@ -315,16 +364,16 @@ async def update_org_member(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OrganizationMemberRead:
member = await session.get(OrganizationMember, member_id)
if member is None or member.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
member_id=member_id,
)
updates = payload.model_dump(exclude_unset=True)
if "role" in updates and updates["role"] is not None:
member.role = normalize_role(updates["role"])
member.updated_at = utcnow()
session.add(member)
await session.commit()
await session.refresh(member)
updates["role"] = normalize_role(updates["role"])
updates["updated_at"] = utcnow()
member = await crud.patch(session, member, updates)
user = await session.get(User, member.user_id)
return _member_to_read(member, user)
@@ -336,9 +385,11 @@ async def update_member_access(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OrganizationMemberRead:
member = await session.get(OrganizationMember, member_id)
if member is None or member.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
member = await _require_org_member(
session,
organization_id=ctx.organization.id,
member_id=member_id,
)
board_ids = {entry.board_id for entry in payload.board_access}
if board_ids:
@@ -393,10 +444,11 @@ async def remove_org_member(
detail="Organization must have at least one owner",
)
await session.execute(
await exec_dml(
session,
delete(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id) == member.id
)
),
)
user = await session.get(User, member.user_id)
@@ -412,8 +464,7 @@ async def remove_org_member(
user.active_organization_id = fallback_org_id
session.add(user)
await session.delete(member)
await session.commit()
await crud.delete(session, member)
return OkResponse()
@@ -423,10 +474,11 @@ async def list_org_invites(
ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
statement = (
select(OrganizationInvite)
.where(col(OrganizationInvite.organization_id) == ctx.organization.id)
.where(col(OrganizationInvite.accepted_at).is_(None))
api_qs(OrganizationInvite)
.filter(col(OrganizationInvite.organization_id) == ctx.organization.id)
.filter(col(OrganizationInvite.accepted_at).is_(None))
.order_by(col(OrganizationInvite.created_at).desc())
.statement
)
return await paginate(session, statement)
@@ -491,16 +543,18 @@ async def revoke_org_invite(
session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin),
) -> OrganizationInviteRead:
invite = await session.get(OrganizationInvite, invite_id)
if invite is None or invite.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await session.execute(
invite = await _require_org_invite(
session,
organization_id=ctx.organization.id,
invite_id=invite_id,
)
await exec_dml(
session,
delete(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
),
)
await session.delete(invite)
await session.commit()
await crud.delete(session, invite)
return OrganizationInviteRead.model_validate(invite, from_attributes=True)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from app.db.queryset import QuerySet, qs
ModelT = TypeVar("ModelT")
@dataclass(frozen=True)
class APIQuerySet(Generic[ModelT]):
queryset: QuerySet[ModelT]
@property
def statement(self) -> SelectOfScalar[ModelT]:
return self.queryset.statement
def filter(self, *criteria: Any) -> APIQuerySet[ModelT]:
return APIQuerySet(self.queryset.filter(*criteria))
def order_by(self, *ordering: Any) -> APIQuerySet[ModelT]:
return APIQuerySet(self.queryset.order_by(*ordering))
def limit(self, value: int) -> APIQuerySet[ModelT]:
return APIQuerySet(self.queryset.limit(value))
def offset(self, value: int) -> APIQuerySet[ModelT]:
return APIQuerySet(self.queryset.offset(value))
async def all(self, session: AsyncSession) -> list[ModelT]:
return await self.queryset.all(session)
async def first(self, session: AsyncSession) -> ModelT | None:
return await self.queryset.first(session)
async def first_or_404(
self,
session: AsyncSession,
*,
detail: str | None = None,
) -> ModelT:
obj = await self.first(session)
if obj is not None:
return obj
if detail is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail)
def api_qs(model: type[ModelT]) -> APIQuerySet[ModelT]:
return APIQuerySet(qs(model))

View File

@@ -27,6 +27,7 @@ from app.core.auth import AuthContext
from app.core.time import utcnow
from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session
from app.db.sqlmodel_exec import exec_dml
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.activity_events import ActivityEvent
@@ -990,16 +991,17 @@ async def delete_task(
if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await require_board_access(session, user=auth.user, board=board, write=True)
await session.execute(delete(ActivityEvent).where(col(ActivityEvent.task_id) == task.id))
await session.execute(delete(TaskFingerprint).where(col(TaskFingerprint.task_id) == task.id))
await session.execute(delete(Approval).where(col(Approval.task_id) == task.id))
await session.execute(
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id) == task.id))
await exec_dml(session, delete(TaskFingerprint).where(col(TaskFingerprint.task_id) == task.id))
await exec_dml(session, delete(Approval).where(col(Approval.task_id) == task.id))
await exec_dml(
session,
delete(TaskDependency).where(
or_(
col(TaskDependency.task_id) == task.id,
col(TaskDependency.depends_on_task_id) == task.id,
)
)
),
)
await session.delete(task)
await session.commit()