refactor: replace SQLModel with QueryModel in various models and update query methods
This commit is contained in:
@@ -2,12 +2,11 @@ from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlmodel import col, select
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api.deps import require_org_admin
|
||||
from app.api.queryset import api_qs
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.core.time import utcnow
|
||||
@@ -43,14 +42,14 @@ async def _require_gateway(
|
||||
gateway_id: UUID,
|
||||
organization_id: UUID,
|
||||
) -> Gateway:
|
||||
return await (
|
||||
api_qs(Gateway)
|
||||
.filter(
|
||||
col(Gateway.id) == gateway_id,
|
||||
col(Gateway.organization_id) == organization_id,
|
||||
)
|
||||
.first_or_404(session, detail="Gateway not found")
|
||||
gateway = (
|
||||
await Gateway.objects.by_id(gateway_id)
|
||||
.filter(col(Gateway.organization_id) == organization_id)
|
||||
.first(session)
|
||||
)
|
||||
if gateway is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
return gateway
|
||||
|
||||
|
||||
async def _find_main_agent(
|
||||
@@ -60,26 +59,22 @@ async def _find_main_agent(
|
||||
previous_session_key: str | None = None,
|
||||
) -> Agent | None:
|
||||
if gateway.main_session_key:
|
||||
agent = (
|
||||
await session.exec(
|
||||
select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key)
|
||||
)
|
||||
).first()
|
||||
agent = await Agent.objects.filter_by(openclaw_session_id=gateway.main_session_key).first(
|
||||
session
|
||||
)
|
||||
if agent:
|
||||
return agent
|
||||
if previous_session_key:
|
||||
agent = (
|
||||
await session.exec(
|
||||
select(Agent).where(Agent.openclaw_session_id == previous_session_key)
|
||||
)
|
||||
).first()
|
||||
agent = await Agent.objects.filter_by(openclaw_session_id=previous_session_key).first(
|
||||
session
|
||||
)
|
||||
if agent:
|
||||
return agent
|
||||
names = {_main_agent_name(gateway)}
|
||||
if previous_name:
|
||||
names.add(f"{previous_name} Main")
|
||||
for name in names:
|
||||
agent = (await session.exec(select(Agent).where(Agent.name == name))).first()
|
||||
agent = await Agent.objects.filter_by(name=name).first(session)
|
||||
if agent:
|
||||
return agent
|
||||
return None
|
||||
@@ -153,8 +148,7 @@ async def list_gateways(
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
) -> DefaultLimitOffsetPage[GatewayRead]:
|
||||
statement = (
|
||||
api_qs(Gateway)
|
||||
.filter(col(Gateway.organization_id) == ctx.organization.id)
|
||||
Gateway.objects.filter_by(organization_id=ctx.organization.id)
|
||||
.order_by(col(Gateway.created_at).desc())
|
||||
.statement
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user