refactor: update migration paths and improve database operation handling
This commit is contained in:
@@ -2,14 +2,16 @@ from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api.deps import require_org_admin
|
||||
from app.api.queryset import api_qs
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.core.time import utcnow
|
||||
from app.db import crud
|
||||
from app.db.pagination import paginate
|
||||
from app.db.session import get_session
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
@@ -35,6 +37,22 @@ def _main_agent_name(gateway: Gateway) -> str:
|
||||
return f"{gateway.name} Main"
|
||||
|
||||
|
||||
async def _require_gateway(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
gateway_id: UUID,
|
||||
organization_id: UUID,
|
||||
) -> Gateway:
|
||||
return await (
|
||||
api_qs(Gateway)
|
||||
.filter(
|
||||
col(Gateway.id) == gateway_id,
|
||||
col(Gateway.organization_id) == organization_id,
|
||||
)
|
||||
.first_or_404(session, detail="Gateway not found")
|
||||
)
|
||||
|
||||
|
||||
async def _find_main_agent(
|
||||
session: AsyncSession,
|
||||
gateway: Gateway,
|
||||
@@ -135,9 +153,10 @@ async def list_gateways(
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> DefaultLimitOffsetPage[GatewayRead]:
|
||||
statement = (
|
||||
select(Gateway)
|
||||
.where(col(Gateway.organization_id) == ctx.organization.id)
|
||||
api_qs(Gateway)
|
||||
.filter(col(Gateway.organization_id) == ctx.organization.id)
|
||||
.order_by(col(Gateway.created_at).desc())
|
||||
.statement
|
||||
)
|
||||
return await paginate(session, statement)
|
||||
|
||||
@@ -151,10 +170,7 @@ async def create_gateway(
|
||||
) -> Gateway:
|
||||
data = payload.model_dump()
|
||||
data["organization_id"] = ctx.organization.id
|
||||
gateway = Gateway.model_validate(data)
|
||||
session.add(gateway)
|
||||
await session.commit()
|
||||
await session.refresh(gateway)
|
||||
gateway = await crud.create(session, Gateway, **data)
|
||||
await _ensure_main_agent(session, gateway, auth, action="provision")
|
||||
return gateway
|
||||
|
||||
@@ -165,10 +181,11 @@ async def get_gateway(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> Gateway:
|
||||
gateway = await session.get(Gateway, gateway_id)
|
||||
if gateway is None or gateway.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
return gateway
|
||||
return await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{gateway_id}", response_model=GatewayRead)
|
||||
@@ -179,17 +196,15 @@ async def update_gateway(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> Gateway:
|
||||
gateway = await session.get(Gateway, gateway_id)
|
||||
if gateway is None or gateway.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
gateway = await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
previous_name = gateway.name
|
||||
previous_session_key = gateway.main_session_key
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for key, value in updates.items():
|
||||
setattr(gateway, key, value)
|
||||
session.add(gateway)
|
||||
await session.commit()
|
||||
await session.refresh(gateway)
|
||||
await crud.patch(session, gateway, updates)
|
||||
await _ensure_main_agent(
|
||||
session,
|
||||
gateway,
|
||||
@@ -213,9 +228,11 @@ async def sync_gateway_templates(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> GatewayTemplatesSyncResult:
|
||||
gateway = await session.get(Gateway, gateway_id)
|
||||
if gateway is None or gateway.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
gateway = await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
return await sync_gateway_templates_service(
|
||||
session,
|
||||
gateway,
|
||||
@@ -234,9 +251,10 @@ async def delete_gateway(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> OkResponse:
|
||||
gateway = await session.get(Gateway, gateway_id)
|
||||
if gateway is None or gateway.organization_id != ctx.organization.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
await session.delete(gateway)
|
||||
await session.commit()
|
||||
gateway = await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
await crud.delete(session, gateway)
|
||||
return OkResponse()
|
||||
|
||||
Reference in New Issue
Block a user