feat: add validation for minimum length on various fields and update type definitions
This commit is contained in:
@@ -4,15 +4,18 @@ from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.core.time import utcnow
|
||||
from app.db.session import get_session
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||
from app.models.agents import Agent
|
||||
from app.models.gateways import Gateway
|
||||
from app.schemas.common import OkResponse
|
||||
from app.schemas.gateways import GatewayCreate, GatewayRead, GatewayUpdate
|
||||
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent
|
||||
|
||||
@@ -235,21 +238,25 @@ def _main_agent_name(gateway: Gateway) -> str:
|
||||
return f"{gateway.name} Main"
|
||||
|
||||
|
||||
def _find_main_agent(
|
||||
session: Session,
|
||||
async def _find_main_agent(
|
||||
session: AsyncSession,
|
||||
gateway: Gateway,
|
||||
previous_name: str | None = None,
|
||||
previous_session_key: str | None = None,
|
||||
) -> Agent | None:
|
||||
if gateway.main_session_key:
|
||||
agent = session.exec(
|
||||
select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key)
|
||||
agent = (
|
||||
await session.exec(
|
||||
select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key)
|
||||
)
|
||||
).first()
|
||||
if agent:
|
||||
return agent
|
||||
if previous_session_key:
|
||||
agent = session.exec(
|
||||
select(Agent).where(Agent.openclaw_session_id == previous_session_key)
|
||||
agent = (
|
||||
await session.exec(
|
||||
select(Agent).where(Agent.openclaw_session_id == previous_session_key)
|
||||
)
|
||||
).first()
|
||||
if agent:
|
||||
return agent
|
||||
@@ -257,14 +264,14 @@ def _find_main_agent(
|
||||
if previous_name:
|
||||
names.add(f"{previous_name} Main")
|
||||
for name in names:
|
||||
agent = session.exec(select(Agent).where(Agent.name == name)).first()
|
||||
agent = (await session.exec(select(Agent).where(Agent.name == name))).first()
|
||||
if agent:
|
||||
return agent
|
||||
return None
|
||||
|
||||
|
||||
async def _ensure_main_agent(
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
gateway: Gateway,
|
||||
auth: AuthContext,
|
||||
*,
|
||||
@@ -274,7 +281,7 @@ async def _ensure_main_agent(
|
||||
) -> Agent | None:
|
||||
if not gateway.url or not gateway.main_session_key:
|
||||
return None
|
||||
agent = _find_main_agent(session, gateway, previous_name, previous_session_key)
|
||||
agent = await _find_main_agent(session, gateway, previous_name, previous_session_key)
|
||||
if agent is None:
|
||||
agent = Agent(
|
||||
name=_main_agent_name(gateway),
|
||||
@@ -294,14 +301,14 @@ async def _ensure_main_agent(
|
||||
agent.openclaw_session_id = gateway.main_session_key
|
||||
raw_token = generate_agent_token()
|
||||
agent.agent_token_hash = hash_agent_token(raw_token)
|
||||
agent.provision_requested_at = datetime.utcnow()
|
||||
agent.provision_requested_at = utcnow()
|
||||
agent.provision_action = action
|
||||
agent.updated_at = datetime.utcnow()
|
||||
agent.updated_at = utcnow()
|
||||
if agent.heartbeat_config is None:
|
||||
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
|
||||
session.add(agent)
|
||||
session.commit()
|
||||
session.refresh(agent)
|
||||
await session.commit()
|
||||
await session.refresh(agent)
|
||||
try:
|
||||
await provision_main_agent(agent, gateway, raw_token, auth.user, action=action)
|
||||
await ensure_session(
|
||||
@@ -356,26 +363,24 @@ async def _send_skyll_disable_message(gateway: Gateway) -> None:
|
||||
|
||||
|
||||
@router.get("", response_model=list[GatewayRead])
|
||||
def list_gateways(
|
||||
session: Session = Depends(get_session),
|
||||
async def list_gateways(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> list[Gateway]:
|
||||
return list(session.exec(select(Gateway)))
|
||||
return list(await session.exec(select(Gateway)))
|
||||
|
||||
|
||||
@router.post("", response_model=GatewayRead)
|
||||
async def create_gateway(
|
||||
payload: GatewayCreate,
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Gateway:
|
||||
data = payload.model_dump()
|
||||
if data.get("token") == "":
|
||||
data["token"] = None
|
||||
gateway = Gateway.model_validate(data)
|
||||
session.add(gateway)
|
||||
session.commit()
|
||||
session.refresh(gateway)
|
||||
await session.commit()
|
||||
await session.refresh(gateway)
|
||||
await _ensure_main_agent(session, gateway, auth, action="provision")
|
||||
if gateway.skyll_enabled:
|
||||
try:
|
||||
@@ -386,12 +391,12 @@ async def create_gateway(
|
||||
|
||||
|
||||
@router.get("/{gateway_id}", response_model=GatewayRead)
|
||||
def get_gateway(
|
||||
async def get_gateway(
|
||||
gateway_id: UUID,
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Gateway:
|
||||
gateway = session.get(Gateway, gateway_id)
|
||||
gateway = await session.get(Gateway, gateway_id)
|
||||
if gateway is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
return gateway
|
||||
@@ -401,23 +406,21 @@ def get_gateway(
|
||||
async def update_gateway(
|
||||
gateway_id: UUID,
|
||||
payload: GatewayUpdate,
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Gateway:
|
||||
gateway = session.get(Gateway, gateway_id)
|
||||
gateway = await session.get(Gateway, gateway_id)
|
||||
if gateway is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
previous_name = gateway.name
|
||||
previous_session_key = gateway.main_session_key
|
||||
previous_skyll_enabled = gateway.skyll_enabled
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if updates.get("token") == "":
|
||||
updates["token"] = None
|
||||
for key, value in updates.items():
|
||||
setattr(gateway, key, value)
|
||||
session.add(gateway)
|
||||
session.commit()
|
||||
session.refresh(gateway)
|
||||
await session.commit()
|
||||
await session.refresh(gateway)
|
||||
await _ensure_main_agent(
|
||||
session,
|
||||
gateway,
|
||||
@@ -439,15 +442,15 @@ async def update_gateway(
|
||||
return gateway
|
||||
|
||||
|
||||
@router.delete("/{gateway_id}")
|
||||
def delete_gateway(
|
||||
@router.delete("/{gateway_id}", response_model=OkResponse)
|
||||
async def delete_gateway(
|
||||
gateway_id: UUID,
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> dict[str, bool]:
|
||||
gateway = session.get(Gateway, gateway_id)
|
||||
) -> OkResponse:
|
||||
gateway = await session.get(Gateway, gateway_id)
|
||||
if gateway is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
session.delete(gateway)
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
await session.delete(gateway)
|
||||
await session.commit()
|
||||
return OkResponse()
|
||||
|
||||
Reference in New Issue
Block a user