feat(models): implement model routing and management pages with dynamic imports
This commit is contained in:
165
backend/app/api/model_registry.py
Normal file
165
backend/app/api/model_registry.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""API routes for gateway model registry and provider-auth management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from app.api.deps import require_org_admin
|
||||
from app.db.session import get_session
|
||||
from app.schemas.common import OkResponse
|
||||
from app.schemas.llm_models import (
|
||||
GatewayModelPullResult,
|
||||
GatewayModelSyncResult,
|
||||
LlmModelCreate,
|
||||
LlmModelRead,
|
||||
LlmModelUpdate,
|
||||
LlmProviderAuthCreate,
|
||||
LlmProviderAuthRead,
|
||||
LlmProviderAuthUpdate,
|
||||
)
|
||||
from app.services.openclaw.model_registry_service import GatewayModelRegistryService
|
||||
from app.services.organizations import OrganizationContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter(prefix="/model-registry", tags=["model-registry"])
|
||||
|
||||
SESSION_DEP = Depends(get_session)
|
||||
ORG_ADMIN_DEP = Depends(require_org_admin)
|
||||
GATEWAY_ID_QUERY = Query(default=None)
|
||||
|
||||
|
||||
@router.get("/provider-auth", response_model=list[LlmProviderAuthRead])
|
||||
async def list_provider_auth(
|
||||
gateway_id: UUID | None = GATEWAY_ID_QUERY,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> list[LlmProviderAuthRead]:
|
||||
"""List provider auth records for the active organization."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
return await service.list_provider_auth(ctx=ctx, gateway_id=gateway_id)
|
||||
|
||||
|
||||
@router.post("/provider-auth", response_model=LlmProviderAuthRead)
|
||||
async def create_provider_auth(
|
||||
payload: LlmProviderAuthCreate,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> LlmProviderAuthRead:
|
||||
"""Create a provider auth record and sync gateway config."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
return await service.create_provider_auth(payload=payload, ctx=ctx)
|
||||
|
||||
|
||||
@router.patch("/provider-auth/{provider_auth_id}", response_model=LlmProviderAuthRead)
|
||||
async def update_provider_auth(
|
||||
provider_auth_id: UUID,
|
||||
payload: LlmProviderAuthUpdate,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> LlmProviderAuthRead:
|
||||
"""Patch a provider auth record and sync gateway config."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
return await service.update_provider_auth(
|
||||
provider_auth_id=provider_auth_id,
|
||||
payload=payload,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/provider-auth/{provider_auth_id}", response_model=OkResponse)
|
||||
async def delete_provider_auth(
|
||||
provider_auth_id: UUID,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OkResponse:
|
||||
"""Delete a provider auth record and sync gateway config."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
await service.delete_provider_auth(provider_auth_id=provider_auth_id, ctx=ctx)
|
||||
return OkResponse()
|
||||
|
||||
|
||||
@router.get("/models", response_model=list[LlmModelRead])
|
||||
async def list_models(
|
||||
gateway_id: UUID | None = GATEWAY_ID_QUERY,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> list[LlmModelRead]:
|
||||
"""List gateway model catalog entries for the active organization."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
return await service.list_models(ctx=ctx, gateway_id=gateway_id)
|
||||
|
||||
|
||||
@router.post("/models", response_model=LlmModelRead)
|
||||
async def create_model(
|
||||
payload: LlmModelCreate,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> LlmModelRead:
|
||||
"""Create a model catalog entry and sync gateway config."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
return await service.create_model(payload=payload, ctx=ctx)
|
||||
|
||||
|
||||
@router.patch("/models/{model_id}", response_model=LlmModelRead)
|
||||
async def update_model(
|
||||
model_id: UUID,
|
||||
payload: LlmModelUpdate,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> LlmModelRead:
|
||||
"""Patch a model catalog entry and sync gateway config."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
return await service.update_model(model_id=model_id, payload=payload, ctx=ctx)
|
||||
|
||||
|
||||
@router.delete("/models/{model_id}", response_model=OkResponse)
|
||||
async def delete_model(
|
||||
model_id: UUID,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OkResponse:
|
||||
"""Delete a model catalog entry and sync gateway config."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
await service.delete_model(model_id=model_id, ctx=ctx)
|
||||
return OkResponse()
|
||||
|
||||
|
||||
@router.post("/gateways/{gateway_id}/sync", response_model=GatewayModelSyncResult)
|
||||
async def sync_gateway_models(
|
||||
gateway_id: UUID,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> GatewayModelSyncResult:
|
||||
"""Push provider auth + model catalog + agent model links to a gateway."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
gateway = await service.require_gateway(
|
||||
gateway_id=gateway_id,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
return await service.sync_gateway_config(
|
||||
gateway=gateway,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/gateways/{gateway_id}/pull", response_model=GatewayModelPullResult)
|
||||
async def pull_gateway_models(
|
||||
gateway_id: UUID,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> GatewayModelPullResult:
|
||||
"""Pull provider auth + model catalog + agent model links from a gateway."""
|
||||
service = GatewayModelRegistryService(session)
|
||||
gateway = await service.require_gateway(
|
||||
gateway_id=gateway_id,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
return await service.pull_gateway_config(
|
||||
gateway=gateway,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
@@ -22,6 +22,7 @@ from app.api.boards import router as boards_router
|
||||
from app.api.gateway import router as gateway_router
|
||||
from app.api.gateways import router as gateways_router
|
||||
from app.api.metrics import router as metrics_router
|
||||
from app.api.model_registry import router as model_registry_router
|
||||
from app.api.organizations import router as organizations_router
|
||||
from app.api.souls_directory import router as souls_directory_router
|
||||
from app.api.tasks import router as tasks_router
|
||||
@@ -98,6 +99,7 @@ api_v1.include_router(activity_router)
|
||||
api_v1.include_router(gateway_router)
|
||||
api_v1.include_router(gateways_router)
|
||||
api_v1.include_router(metrics_router)
|
||||
api_v1.include_router(model_registry_router)
|
||||
api_v1.include_router(organizations_router)
|
||||
api_v1.include_router(souls_directory_router)
|
||||
api_v1.include_router(board_groups_router)
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.models.board_memory import BoardMemory
|
||||
from app.models.board_onboarding import BoardOnboardingSession
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.llm import LlmModel, LlmProviderAuth
|
||||
from app.models.organization_board_access import OrganizationBoardAccess
|
||||
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
|
||||
from app.models.organization_invites import OrganizationInvite
|
||||
@@ -31,6 +32,8 @@ __all__ = [
|
||||
"BoardGroup",
|
||||
"Board",
|
||||
"Gateway",
|
||||
"LlmModel",
|
||||
"LlmProviderAuth",
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"OrganizationBoardAccess",
|
||||
|
||||
@@ -44,5 +44,10 @@ class Agent(QueryModel, table=True):
|
||||
delete_confirm_token_hash: str | None = Field(default=None, index=True)
|
||||
last_seen_at: datetime | None = Field(default=None)
|
||||
is_board_lead: bool = Field(default=False, index=True)
|
||||
primary_model_id: UUID | None = Field(default=None, foreign_key="llm_models.id", index=True)
|
||||
fallback_model_ids: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
46
backend/app/models/llm.py
Normal file
46
backend/app/models/llm.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Models for gateway-scoped LLM provider auth and model catalog records."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlmodel import Field
|
||||
|
||||
from app.core.time import utcnow
|
||||
from app.models.base import QueryModel
|
||||
|
||||
RUNTIME_ANNOTATION_TYPES = (datetime,)
|
||||
|
||||
|
||||
class LlmProviderAuth(QueryModel, table=True):
|
||||
"""Provider auth settings to write into a specific gateway config."""
|
||||
|
||||
__tablename__ = "llm_provider_auth" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
||||
gateway_id: UUID = Field(foreign_key="gateways.id", index=True)
|
||||
provider: str = Field(index=True)
|
||||
config_path: str
|
||||
secret: str
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
|
||||
class LlmModel(QueryModel, table=True):
|
||||
"""Gateway model catalog entries available for agent assignment."""
|
||||
|
||||
__tablename__ = "llm_models" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
|
||||
gateway_id: UUID = Field(foreign_key="gateways.id", index=True)
|
||||
provider: str = Field(index=True)
|
||||
model_id: str = Field(index=True)
|
||||
display_name: str
|
||||
settings: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
@@ -13,6 +13,16 @@ from app.schemas.board_onboarding import (
|
||||
)
|
||||
from app.schemas.boards import BoardCreate, BoardRead, BoardUpdate
|
||||
from app.schemas.gateways import GatewayCreate, GatewayRead, GatewayUpdate
|
||||
from app.schemas.llm_models import (
|
||||
GatewayModelPullResult,
|
||||
GatewayModelSyncResult,
|
||||
LlmModelCreate,
|
||||
LlmModelRead,
|
||||
LlmModelUpdate,
|
||||
LlmProviderAuthCreate,
|
||||
LlmProviderAuthRead,
|
||||
LlmProviderAuthUpdate,
|
||||
)
|
||||
from app.schemas.metrics import DashboardMetrics
|
||||
from app.schemas.organizations import (
|
||||
OrganizationActiveUpdate,
|
||||
@@ -57,6 +67,14 @@ __all__ = [
|
||||
"GatewayRead",
|
||||
"GatewayUpdate",
|
||||
"DashboardMetrics",
|
||||
"GatewayModelPullResult",
|
||||
"GatewayModelSyncResult",
|
||||
"LlmModelCreate",
|
||||
"LlmModelRead",
|
||||
"LlmModelUpdate",
|
||||
"LlmProviderAuthCreate",
|
||||
"LlmProviderAuthRead",
|
||||
"LlmProviderAuthUpdate",
|
||||
"OrganizationActiveUpdate",
|
||||
"OrganizationCreate",
|
||||
"OrganizationInviteAccept",
|
||||
|
||||
@@ -39,6 +39,27 @@ def _normalize_identity_profile(
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_model_ids(
|
||||
model_ids: object,
|
||||
) -> list[UUID] | None:
|
||||
if model_ids is None:
|
||||
return None
|
||||
if not isinstance(model_ids, (list, tuple, set)):
|
||||
raise ValueError("fallback_model_ids must be a list")
|
||||
normalized: list[UUID] = []
|
||||
seen: set[UUID] = set()
|
||||
for raw in model_ids:
|
||||
candidate = str(raw).strip()
|
||||
if not candidate:
|
||||
continue
|
||||
model_id = UUID(candidate)
|
||||
if model_id in seen:
|
||||
continue
|
||||
seen.add(model_id)
|
||||
normalized.append(model_id)
|
||||
return normalized or None
|
||||
|
||||
|
||||
class AgentBase(SQLModel):
|
||||
"""Common fields shared by agent create/read/update payloads."""
|
||||
|
||||
@@ -46,6 +67,8 @@ class AgentBase(SQLModel):
|
||||
name: NonEmptyStr
|
||||
status: str = "provisioning"
|
||||
heartbeat_config: dict[str, Any] | None = None
|
||||
primary_model_id: UUID | None = None
|
||||
fallback_model_ids: list[UUID] | None = None
|
||||
identity_profile: dict[str, Any] | None = None
|
||||
identity_template: str | None = None
|
||||
soul_template: str | None = None
|
||||
@@ -70,6 +93,15 @@ class AgentBase(SQLModel):
|
||||
"""Normalize identity-profile values into trimmed string mappings."""
|
||||
return _normalize_identity_profile(value)
|
||||
|
||||
@field_validator("fallback_model_ids", mode="before")
|
||||
@classmethod
|
||||
def normalize_fallback_model_ids(
|
||||
cls,
|
||||
value: object,
|
||||
) -> list[UUID] | None:
|
||||
"""Normalize fallback model ids into ordered UUID values."""
|
||||
return _normalize_model_ids(value)
|
||||
|
||||
|
||||
class AgentCreate(AgentBase):
|
||||
"""Payload for creating a new agent."""
|
||||
@@ -83,6 +115,8 @@ class AgentUpdate(SQLModel):
|
||||
name: NonEmptyStr | None = None
|
||||
status: str | None = None
|
||||
heartbeat_config: dict[str, Any] | None = None
|
||||
primary_model_id: UUID | None = None
|
||||
fallback_model_ids: list[UUID] | None = None
|
||||
identity_profile: dict[str, Any] | None = None
|
||||
identity_template: str | None = None
|
||||
soul_template: str | None = None
|
||||
@@ -107,6 +141,15 @@ class AgentUpdate(SQLModel):
|
||||
"""Normalize identity-profile values into trimmed string mappings."""
|
||||
return _normalize_identity_profile(value)
|
||||
|
||||
@field_validator("fallback_model_ids", mode="before")
|
||||
@classmethod
|
||||
def normalize_fallback_model_ids(
|
||||
cls,
|
||||
value: object,
|
||||
) -> list[UUID] | None:
|
||||
"""Normalize fallback model ids into ordered UUID values."""
|
||||
return _normalize_model_ids(value)
|
||||
|
||||
|
||||
class AgentRead(AgentBase):
|
||||
"""Public agent representation returned by the API."""
|
||||
|
||||
167
backend/app/schemas/llm_models.py
Normal file
167
backend/app/schemas/llm_models.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Schemas for LLM provider auth, model catalog, and gateway sync payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import field_validator
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from app.schemas.common import NonEmptyStr
|
||||
|
||||
RUNTIME_ANNOTATION_TYPES = (datetime, UUID, NonEmptyStr)
|
||||
|
||||
|
||||
def _normalize_provider(value: object) -> str | object:
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
return normalized or value
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_config_path(value: object) -> str | object:
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip()
|
||||
return normalized or value
|
||||
return value
|
||||
|
||||
|
||||
def _default_provider_config_path(provider: str) -> str:
|
||||
return f"providers.{provider}.apiKey"
|
||||
|
||||
|
||||
class LlmProviderAuthBase(SQLModel):
|
||||
"""Shared provider auth fields."""
|
||||
|
||||
gateway_id: UUID
|
||||
provider: NonEmptyStr
|
||||
config_path: NonEmptyStr | None = None
|
||||
|
||||
@field_validator("provider", mode="before")
|
||||
@classmethod
|
||||
def normalize_provider(cls, value: object) -> str | object:
|
||||
return _normalize_provider(value)
|
||||
|
||||
@field_validator("config_path", mode="before")
|
||||
@classmethod
|
||||
def normalize_config_path(cls, value: object) -> str | object:
|
||||
return _normalize_config_path(value)
|
||||
|
||||
|
||||
class LlmProviderAuthCreate(LlmProviderAuthBase):
|
||||
"""Payload used to create a provider auth record."""
|
||||
|
||||
secret: NonEmptyStr
|
||||
|
||||
|
||||
class LlmProviderAuthUpdate(SQLModel):
|
||||
"""Payload used to patch an existing provider auth record."""
|
||||
|
||||
provider: NonEmptyStr | None = None
|
||||
config_path: NonEmptyStr | None = None
|
||||
secret: NonEmptyStr | None = None
|
||||
|
||||
@field_validator("provider", mode="before")
|
||||
@classmethod
|
||||
def normalize_provider(cls, value: object) -> str | object:
|
||||
return _normalize_provider(value)
|
||||
|
||||
@field_validator("config_path", mode="before")
|
||||
@classmethod
|
||||
def normalize_config_path(cls, value: object) -> str | object:
|
||||
return _normalize_config_path(value)
|
||||
|
||||
|
||||
class LlmProviderAuthRead(SQLModel):
|
||||
"""Public provider auth payload (secret value is never returned)."""
|
||||
|
||||
id: UUID
|
||||
organization_id: UUID
|
||||
gateway_id: UUID
|
||||
provider: str
|
||||
config_path: str
|
||||
has_secret: bool = True
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class LlmModelBase(SQLModel):
|
||||
"""Shared gateway model catalog fields."""
|
||||
|
||||
gateway_id: UUID
|
||||
provider: NonEmptyStr
|
||||
model_id: NonEmptyStr
|
||||
display_name: NonEmptyStr
|
||||
settings: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("provider", mode="before")
|
||||
@classmethod
|
||||
def normalize_provider(cls, value: object) -> str | object:
|
||||
return _normalize_provider(value)
|
||||
|
||||
|
||||
class LlmModelCreate(LlmModelBase):
|
||||
"""Payload used to create a model catalog entry."""
|
||||
|
||||
|
||||
class LlmModelUpdate(SQLModel):
|
||||
"""Payload used to patch an existing model catalog entry."""
|
||||
|
||||
provider: NonEmptyStr | None = None
|
||||
model_id: NonEmptyStr | None = None
|
||||
display_name: NonEmptyStr | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("provider", mode="before")
|
||||
@classmethod
|
||||
def normalize_provider(cls, value: object) -> str | object:
|
||||
return _normalize_provider(value)
|
||||
|
||||
|
||||
class LlmModelRead(SQLModel):
|
||||
"""Public model catalog entry payload."""
|
||||
|
||||
id: UUID
|
||||
organization_id: UUID
|
||||
gateway_id: UUID
|
||||
provider: str
|
||||
model_id: str
|
||||
display_name: str
|
||||
settings: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class GatewayModelSyncResult(SQLModel):
|
||||
"""Summary of model/provider config sync operations for a gateway."""
|
||||
|
||||
gateway_id: UUID
|
||||
provider_auth_patched: int
|
||||
model_catalog_patched: int
|
||||
agent_models_patched: int
|
||||
sessions_patched: int
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class GatewayModelPullResult(SQLModel):
|
||||
"""Summary of model/provider config pull operations for a gateway."""
|
||||
|
||||
gateway_id: UUID
|
||||
provider_auth_imported: int
|
||||
model_catalog_imported: int
|
||||
agent_models_imported: int
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GatewayModelPullResult",
|
||||
"GatewayModelSyncResult",
|
||||
"LlmModelCreate",
|
||||
"LlmModelRead",
|
||||
"LlmModelUpdate",
|
||||
"LlmProviderAuthCreate",
|
||||
"LlmProviderAuthRead",
|
||||
"LlmProviderAuthUpdate",
|
||||
]
|
||||
1002
backend/app/services/openclaw/model_registry_service.py
Normal file
1002
backend/app/services/openclaw/model_registry_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -492,10 +492,32 @@ async def _gateway_config_agent_list(
|
||||
msg = "config.get returned invalid payload"
|
||||
raise OpenClawGatewayError(msg)
|
||||
|
||||
data = cfg.get("config") or cfg.get("parsed") or {}
|
||||
if not isinstance(data, dict):
|
||||
msg = "config.get returned invalid config"
|
||||
raise OpenClawGatewayError(msg)
|
||||
# Prefer parsed object over raw serialized config to support older gateways.
|
||||
raw_parsed = cfg.get("parsed")
|
||||
raw_config = cfg.get("config")
|
||||
data: dict[str, Any]
|
||||
|
||||
def _parse_json_config(raw: str) -> dict[str, Any]:
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
msg = "config.get returned invalid config"
|
||||
raise OpenClawGatewayError(msg) from exc
|
||||
if not isinstance(parsed, dict):
|
||||
msg = "config.get returned invalid config"
|
||||
raise OpenClawGatewayError(msg)
|
||||
return parsed
|
||||
|
||||
if isinstance(raw_parsed, dict):
|
||||
data = raw_parsed
|
||||
elif isinstance(raw_config, dict):
|
||||
data = raw_config
|
||||
elif isinstance(raw_parsed, str) and raw_parsed.strip():
|
||||
data = _parse_json_config(raw_parsed)
|
||||
elif isinstance(raw_config, str) and raw_config.strip():
|
||||
data = _parse_json_config(raw_config)
|
||||
else:
|
||||
data = {}
|
||||
|
||||
agents_section = data.get("agents") or {}
|
||||
agents_list = agents_section.get("list") or []
|
||||
|
||||
@@ -32,6 +32,7 @@ from app.models.agents import Agent
|
||||
from app.models.board_memory import BoardMemory
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.llm import LlmModel
|
||||
from app.models.organizations import Organization
|
||||
from app.models.tasks import Task
|
||||
from app.schemas.agents import (
|
||||
@@ -72,6 +73,7 @@ from app.services.openclaw.internal.session_keys import (
|
||||
board_agent_session_key,
|
||||
board_lead_session_key,
|
||||
)
|
||||
from app.services.openclaw.model_registry_service import GatewayModelRegistryService
|
||||
from app.services.openclaw.policies import OpenClawAuthorizationPolicy
|
||||
from app.services.openclaw.provisioning import (
|
||||
OpenClawGatewayControlPlane,
|
||||
@@ -959,6 +961,71 @@ class AgentLifecycleService(OpenClawDBService):
|
||||
detail="An agent with this name already exists in this gateway workspace.",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalized_fallback_ids(value: object) -> list[UUID]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
normalized: list[UUID] = []
|
||||
seen: set[UUID] = set()
|
||||
for raw in value:
|
||||
candidate = str(raw).strip()
|
||||
if not candidate:
|
||||
continue
|
||||
try:
|
||||
model_id = UUID(candidate)
|
||||
except ValueError:
|
||||
continue
|
||||
if model_id in seen:
|
||||
continue
|
||||
seen.add(model_id)
|
||||
normalized.append(model_id)
|
||||
return normalized
|
||||
|
||||
async def normalize_agent_model_assignments(
|
||||
self,
|
||||
*,
|
||||
gateway_id: UUID,
|
||||
primary_model_id: UUID | None,
|
||||
fallback_model_ids: list[UUID],
|
||||
) -> tuple[UUID | None, list[str] | None]:
|
||||
if primary_model_id is None and not fallback_model_ids:
|
||||
return None, None
|
||||
|
||||
candidate_ids: list[UUID] = []
|
||||
if primary_model_id is not None:
|
||||
candidate_ids.append(primary_model_id)
|
||||
candidate_ids.extend(fallback_model_ids)
|
||||
unique_ids = list(dict.fromkeys(candidate_ids))
|
||||
statement = (
|
||||
select(LlmModel.id)
|
||||
.where(col(LlmModel.gateway_id) == gateway_id)
|
||||
.where(col(LlmModel.id).in_(unique_ids))
|
||||
)
|
||||
with self.session.no_autoflush:
|
||||
valid_ids = set(await self.session.exec(statement))
|
||||
missing = [value for value in unique_ids if value not in valid_ids]
|
||||
if missing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Model assignment includes model ids not in the agent gateway catalog.",
|
||||
)
|
||||
|
||||
filtered_fallback: list[str] = []
|
||||
for fallback_id in fallback_model_ids:
|
||||
if primary_model_id is not None and fallback_id == primary_model_id:
|
||||
continue
|
||||
value = str(fallback_id)
|
||||
if value in filtered_fallback:
|
||||
continue
|
||||
filtered_fallback.append(value)
|
||||
return primary_model_id, (filtered_fallback or None)
|
||||
|
||||
async def sync_gateway_agent_models(self, *, gateway: Gateway) -> None:
|
||||
await GatewayModelRegistryService(self.session).sync_gateway_config(
|
||||
gateway=gateway,
|
||||
organization_id=gateway.organization_id,
|
||||
)
|
||||
|
||||
async def persist_new_agent(
|
||||
self,
|
||||
*,
|
||||
@@ -1177,6 +1244,13 @@ class AgentLifecycleService(OpenClawDBService):
|
||||
detail="Board gateway_id is required",
|
||||
)
|
||||
updates["gateway_id"] = board.gateway_id
|
||||
if "primary_model_id" in updates:
|
||||
primary_value = updates["primary_model_id"]
|
||||
if primary_value is not None and not isinstance(primary_value, UUID):
|
||||
updates["primary_model_id"] = UUID(str(primary_value))
|
||||
if "fallback_model_ids" in updates:
|
||||
normalized_fallback = self._normalized_fallback_ids(updates["fallback_model_ids"])
|
||||
updates["fallback_model_ids"] = [str(model_id) for model_id in normalized_fallback] or None
|
||||
for key, value in updates.items():
|
||||
setattr(agent, key, value)
|
||||
|
||||
@@ -1192,6 +1266,19 @@ class AgentLifecycleService(OpenClawDBService):
|
||||
detail="Board gateway_id is required",
|
||||
)
|
||||
agent.gateway_id = board.gateway_id
|
||||
|
||||
primary_model_id = agent.primary_model_id
|
||||
if primary_model_id is not None and not isinstance(primary_model_id, UUID):
|
||||
primary_model_id = UUID(str(primary_model_id))
|
||||
fallback_model_ids = self._normalized_fallback_ids(agent.fallback_model_ids)
|
||||
normalized_primary, normalized_fallback = await self.normalize_agent_model_assignments(
|
||||
gateway_id=agent.gateway_id,
|
||||
primary_model_id=primary_model_id if isinstance(primary_model_id, UUID) else None,
|
||||
fallback_model_ids=fallback_model_ids,
|
||||
)
|
||||
agent.primary_model_id = normalized_primary
|
||||
agent.fallback_model_ids = normalized_fallback
|
||||
|
||||
agent.updated_at = utcnow()
|
||||
if agent.heartbeat_config is None:
|
||||
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
|
||||
@@ -1487,6 +1574,17 @@ class AgentLifecycleService(OpenClawDBService):
|
||||
gateway, _client_config = await self.require_gateway(board)
|
||||
data = payload.model_dump()
|
||||
data["gateway_id"] = gateway.id
|
||||
primary_model_id = data.get("primary_model_id")
|
||||
if primary_model_id is not None and not isinstance(primary_model_id, UUID):
|
||||
primary_model_id = UUID(str(primary_model_id))
|
||||
fallback_model_ids = self._normalized_fallback_ids(data.get("fallback_model_ids"))
|
||||
normalized_primary, normalized_fallback = await self.normalize_agent_model_assignments(
|
||||
gateway_id=gateway.id,
|
||||
primary_model_id=primary_model_id if isinstance(primary_model_id, UUID) else None,
|
||||
fallback_model_ids=fallback_model_ids,
|
||||
)
|
||||
data["primary_model_id"] = normalized_primary
|
||||
data["fallback_model_ids"] = normalized_fallback
|
||||
requested_name = (data.get("name") or "").strip()
|
||||
await self.ensure_unique_agent_name(
|
||||
board=board,
|
||||
@@ -1502,6 +1600,8 @@ class AgentLifecycleService(OpenClawDBService):
|
||||
user=actor.user if actor.actor_type == "user" else None,
|
||||
force_bootstrap=False,
|
||||
)
|
||||
if agent.primary_model_id is not None or agent.fallback_model_ids:
|
||||
await self.sync_gateway_agent_models(gateway=gateway)
|
||||
self.logger.info("agent.create.success agent_id=%s board_id=%s", agent.id, board.id)
|
||||
return self.to_agent_read(self.with_computed_status(agent))
|
||||
|
||||
@@ -1535,6 +1635,7 @@ class AgentLifecycleService(OpenClawDBService):
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
await self.require_agent_access(agent=agent, ctx=options.context, write=True)
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
sync_model_assignments = "primary_model_id" in updates or "fallback_model_ids" in updates
|
||||
make_main = updates.pop("is_gateway_main", None)
|
||||
await self.validate_agent_update_inputs(
|
||||
ctx=options.context,
|
||||
@@ -1568,6 +1669,8 @@ class AgentLifecycleService(OpenClawDBService):
|
||||
agent=agent,
|
||||
request=provision_request,
|
||||
)
|
||||
if sync_model_assignments or agent.primary_model_id is not None or agent.fallback_model_ids:
|
||||
await self.sync_gateway_agent_models(gateway=target.gateway)
|
||||
self.logger.info("agent.update.success agent_id=%s", agent.id)
|
||||
return self.to_agent_read(self.with_computed_status(agent))
|
||||
|
||||
|
||||
@@ -0,0 +1,285 @@
|
||||
"""add llm model registry
|
||||
|
||||
Revision ID: 9a3fb1158c2d
|
||||
Revises: f4d2b649e93a
|
||||
Create Date: 2026-02-11 21:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9a3fb1158c2d"
|
||||
down_revision = "f4d2b649e93a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _has_index(inspector: sa.Inspector, table: str, index_name: str) -> bool:
|
||||
return any(item.get("name") == index_name for item in inspector.get_indexes(table))
|
||||
|
||||
|
||||
def _has_unique(
|
||||
inspector: sa.Inspector,
|
||||
table: str,
|
||||
*,
|
||||
name: str | None = None,
|
||||
columns: tuple[str, ...] | None = None,
|
||||
) -> bool:
|
||||
unique_constraints = inspector.get_unique_constraints(table)
|
||||
for item in unique_constraints:
|
||||
if name and item.get("name") == name:
|
||||
return True
|
||||
if columns and tuple(item.get("column_names") or ()) == columns:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _column_names(inspector: sa.Inspector, table: str) -> set[str]:
|
||||
return {item["name"] for item in inspector.get_columns(table)}
|
||||
|
||||
|
||||
def _has_foreign_key(
|
||||
inspector: sa.Inspector,
|
||||
table: str,
|
||||
*,
|
||||
constrained_columns: tuple[str, ...],
|
||||
referred_table: str,
|
||||
) -> bool:
|
||||
for item in inspector.get_foreign_keys(table):
|
||||
if tuple(item.get("constrained_columns") or ()) != constrained_columns:
|
||||
continue
|
||||
if item.get("referred_table") != referred_table:
|
||||
continue
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
if not inspector.has_table("llm_provider_auth"):
|
||||
op.create_table(
|
||||
"llm_provider_auth",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("organization_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("gateway_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("config_path", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["gateway_id"], ["gateways.id"]),
|
||||
sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"gateway_id",
|
||||
"provider",
|
||||
"config_path",
|
||||
name="uq_llm_provider_auth_gateway_provider_path",
|
||||
),
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
else:
|
||||
existing_columns = _column_names(inspector, "llm_provider_auth")
|
||||
if "config_path" not in existing_columns:
|
||||
op.add_column(
|
||||
"llm_provider_auth",
|
||||
sa.Column(
|
||||
"config_path",
|
||||
sqlmodel.sql.sqltypes.AutoString(),
|
||||
nullable=False,
|
||||
server_default="providers.openai.apiKey",
|
||||
),
|
||||
)
|
||||
op.alter_column("llm_provider_auth", "config_path", server_default=None)
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_unique(
|
||||
inspector,
|
||||
"llm_provider_auth",
|
||||
name="uq_llm_provider_auth_gateway_provider_path",
|
||||
columns=("gateway_id", "provider", "config_path"),
|
||||
):
|
||||
op.create_unique_constraint(
|
||||
"uq_llm_provider_auth_gateway_provider_path",
|
||||
"llm_provider_auth",
|
||||
["gateway_id", "provider", "config_path"],
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
if not _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_gateway_id")):
|
||||
op.create_index(
|
||||
op.f("ix_llm_provider_auth_gateway_id"),
|
||||
"llm_provider_auth",
|
||||
["gateway_id"],
|
||||
unique=False,
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_index(
|
||||
inspector,
|
||||
"llm_provider_auth",
|
||||
op.f("ix_llm_provider_auth_organization_id"),
|
||||
):
|
||||
op.create_index(
|
||||
op.f("ix_llm_provider_auth_organization_id"),
|
||||
"llm_provider_auth",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_provider")):
|
||||
op.create_index(
|
||||
op.f("ix_llm_provider_auth_provider"),
|
||||
"llm_provider_auth",
|
||||
["provider"],
|
||||
unique=False,
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
if not inspector.has_table("llm_models"):
|
||||
op.create_table(
|
||||
"llm_models",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("organization_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("gateway_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("model_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("settings", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["gateway_id"], ["gateways.id"]),
|
||||
sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("gateway_id", "model_id", name="uq_llm_models_gateway_model_id"),
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
elif not _has_unique(
|
||||
inspector,
|
||||
"llm_models",
|
||||
name="uq_llm_models_gateway_model_id",
|
||||
columns=("gateway_id", "model_id"),
|
||||
):
|
||||
op.create_unique_constraint(
|
||||
"uq_llm_models_gateway_model_id",
|
||||
"llm_models",
|
||||
["gateway_id", "model_id"],
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_gateway_id")):
|
||||
op.create_index(
|
||||
op.f("ix_llm_models_gateway_id"),
|
||||
"llm_models",
|
||||
["gateway_id"],
|
||||
unique=False,
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_model_id")):
|
||||
op.create_index(
|
||||
op.f("ix_llm_models_model_id"),
|
||||
"llm_models",
|
||||
["model_id"],
|
||||
unique=False,
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_organization_id")):
|
||||
op.create_index(
|
||||
op.f("ix_llm_models_organization_id"),
|
||||
"llm_models",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_provider")):
|
||||
op.create_index(
|
||||
op.f("ix_llm_models_provider"),
|
||||
"llm_models",
|
||||
["provider"],
|
||||
unique=False,
|
||||
)
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
agent_columns = _column_names(inspector, "agents")
|
||||
if "primary_model_id" not in agent_columns:
|
||||
op.add_column("agents", sa.Column("primary_model_id", sa.Uuid(), nullable=True))
|
||||
inspector = sa.inspect(bind)
|
||||
if "fallback_model_ids" not in agent_columns:
|
||||
op.add_column("agents", sa.Column("fallback_model_ids", sa.JSON(), nullable=True))
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_index(inspector, "agents", op.f("ix_agents_primary_model_id")):
|
||||
op.create_index(op.f("ix_agents_primary_model_id"), "agents", ["primary_model_id"], unique=False)
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_foreign_key(
|
||||
inspector,
|
||||
"agents",
|
||||
constrained_columns=("primary_model_id",),
|
||||
referred_table="llm_models",
|
||||
):
|
||||
op.create_foreign_key(
|
||||
"fk_agents_primary_model_id_llm_models",
|
||||
"agents",
|
||||
"llm_models",
|
||||
["primary_model_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
if inspector.has_table("agents"):
|
||||
for fk in inspector.get_foreign_keys("agents"):
|
||||
if tuple(fk.get("constrained_columns") or ()) != ("primary_model_id",):
|
||||
continue
|
||||
if fk.get("referred_table") != "llm_models":
|
||||
continue
|
||||
fk_name = fk.get("name")
|
||||
if fk_name:
|
||||
op.drop_constraint(fk_name, "agents", type_="foreignkey")
|
||||
inspector = sa.inspect(bind)
|
||||
if _has_index(inspector, "agents", op.f("ix_agents_primary_model_id")):
|
||||
op.drop_index(op.f("ix_agents_primary_model_id"), table_name="agents")
|
||||
agent_columns = _column_names(inspector, "agents")
|
||||
if "fallback_model_ids" in agent_columns:
|
||||
op.drop_column("agents", "fallback_model_ids")
|
||||
if "primary_model_id" in agent_columns:
|
||||
op.drop_column("agents", "primary_model_id")
|
||||
|
||||
inspector = sa.inspect(bind)
|
||||
if inspector.has_table("llm_models"):
|
||||
if _has_index(inspector, "llm_models", op.f("ix_llm_models_provider")):
|
||||
op.drop_index(op.f("ix_llm_models_provider"), table_name="llm_models")
|
||||
if _has_index(inspector, "llm_models", op.f("ix_llm_models_organization_id")):
|
||||
op.drop_index(op.f("ix_llm_models_organization_id"), table_name="llm_models")
|
||||
if _has_index(inspector, "llm_models", op.f("ix_llm_models_model_id")):
|
||||
op.drop_index(op.f("ix_llm_models_model_id"), table_name="llm_models")
|
||||
if _has_index(inspector, "llm_models", op.f("ix_llm_models_gateway_id")):
|
||||
op.drop_index(op.f("ix_llm_models_gateway_id"), table_name="llm_models")
|
||||
op.drop_table("llm_models")
|
||||
|
||||
inspector = sa.inspect(bind)
|
||||
if inspector.has_table("llm_provider_auth"):
|
||||
if _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_provider")):
|
||||
op.drop_index(op.f("ix_llm_provider_auth_provider"), table_name="llm_provider_auth")
|
||||
if _has_index(
|
||||
inspector,
|
||||
"llm_provider_auth",
|
||||
op.f("ix_llm_provider_auth_organization_id"),
|
||||
):
|
||||
op.drop_index(
|
||||
op.f("ix_llm_provider_auth_organization_id"),
|
||||
table_name="llm_provider_auth",
|
||||
)
|
||||
if _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_gateway_id")):
|
||||
op.drop_index(
|
||||
op.f("ix_llm_provider_auth_gateway_id"),
|
||||
table_name="llm_provider_auth",
|
||||
)
|
||||
op.drop_table("llm_provider_auth")
|
||||
105
backend/tests/test_agent_model_assignment_updates.py
Normal file
105
backend/tests/test_agent_model_assignment_updates.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# ruff: noqa: S101
|
||||
"""Regression tests for agent model-assignment update normalization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from app.services.openclaw.provisioning_db import AgentLifecycleService
|
||||
|
||||
|
||||
class _NoAutoflush:
|
||||
def __init__(self, session: "_SessionStub") -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> None:
|
||||
self._session.in_no_autoflush = True
|
||||
return None
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> bool:
|
||||
self._session.in_no_autoflush = False
|
||||
return False
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, valid_ids: set[UUID]) -> None:
|
||||
self.valid_ids = valid_ids
|
||||
self.in_no_autoflush = False
|
||||
self.commits = 0
|
||||
|
||||
@property
|
||||
def no_autoflush(self) -> _NoAutoflush:
|
||||
return _NoAutoflush(self)
|
||||
|
||||
async def exec(self, _statement: Any) -> list[UUID]:
|
||||
if not self.in_no_autoflush:
|
||||
raise AssertionError("Expected normalize query to run under no_autoflush.")
|
||||
return list(self.valid_ids)
|
||||
|
||||
def add(self, _model: Any) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commits += 1
|
||||
|
||||
async def refresh(self, _model: Any) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _AgentStub:
|
||||
gateway_id: UUID
|
||||
board_id: UUID | None = None
|
||||
is_board_lead: bool = False
|
||||
openclaw_session_id: str | None = None
|
||||
primary_model_id: UUID | None = None
|
||||
fallback_model_ids: list[str] | None = None
|
||||
updated_at: datetime | None = None
|
||||
heartbeat_config: dict[str, Any] | None = field(
|
||||
default_factory=lambda: {"interval_seconds": 5},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_agent_model_assignments_uses_no_autoflush() -> None:
|
||||
primary = uuid4()
|
||||
fallback = uuid4()
|
||||
session = _SessionStub({primary, fallback})
|
||||
service = AgentLifecycleService(session) # type: ignore[arg-type]
|
||||
|
||||
normalized_primary, normalized_fallback = await service.normalize_agent_model_assignments(
|
||||
gateway_id=uuid4(),
|
||||
primary_model_id=primary,
|
||||
fallback_model_ids=[primary, fallback],
|
||||
)
|
||||
|
||||
assert normalized_primary == primary
|
||||
assert normalized_fallback == [str(fallback)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_agent_update_mutations_coerces_fallback_ids_to_strings(monkeypatch) -> None:
|
||||
primary = uuid4()
|
||||
fallback = uuid4()
|
||||
session = _SessionStub({primary, fallback})
|
||||
service = AgentLifecycleService(session) # type: ignore[arg-type]
|
||||
monkeypatch.setattr(service, "get_main_agent_gateway", AsyncMock(return_value=None))
|
||||
|
||||
agent = _AgentStub(gateway_id=uuid4())
|
||||
updates: dict[str, Any] = {
|
||||
"primary_model_id": primary,
|
||||
"fallback_model_ids": [primary, fallback, fallback],
|
||||
}
|
||||
|
||||
await service.apply_agent_update_mutations(agent=agent, updates=updates, make_main=None) # type: ignore[arg-type]
|
||||
|
||||
assert updates["fallback_model_ids"] == [str(primary), str(fallback)]
|
||||
assert agent.primary_model_id == primary
|
||||
assert agent.fallback_model_ids == [str(fallback)]
|
||||
assert session.commits == 1
|
||||
28
backend/tests/test_agent_schema_models.py
Normal file
28
backend/tests/test_agent_schema_models.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# ruff: noqa: S101
|
||||
"""Tests for agent model-assignment schema normalization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.agents import AgentCreate
|
||||
|
||||
|
||||
def test_agent_create_normalizes_fallback_model_ids() -> None:
|
||||
model_a = uuid4()
|
||||
model_b = uuid4()
|
||||
|
||||
payload = AgentCreate(
|
||||
name="Worker",
|
||||
fallback_model_ids=[str(model_a), str(model_b), str(model_a)],
|
||||
)
|
||||
|
||||
assert payload.fallback_model_ids == [model_a, model_b]
|
||||
|
||||
|
||||
def test_agent_create_rejects_non_list_fallback_model_ids() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentCreate(name="Worker", fallback_model_ids="not-a-list")
|
||||
110
backend/tests/test_model_registry_pull_helpers.py
Normal file
110
backend/tests/test_model_registry_pull_helpers.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# ruff: noqa: S101
|
||||
"""Tests for gateway model-registry pull helper normalization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.services.openclaw.model_registry_service import (
|
||||
_extract_config_data,
|
||||
_get_nested_path,
|
||||
_infer_provider_for_model,
|
||||
_model_config,
|
||||
_model_settings,
|
||||
_normalize_provider,
|
||||
_parse_agent_model_value,
|
||||
)
|
||||
|
||||
|
||||
def test_get_nested_path_resolves_existing_value() -> None:
|
||||
source = {"providers": {"openai": {"apiKey": "sk-test"}}}
|
||||
|
||||
assert _get_nested_path(source, ["providers", "openai", "apiKey"]) == "sk-test"
|
||||
assert _get_nested_path(source, ["providers", "anthropic", "apiKey"]) is None
|
||||
|
||||
|
||||
def test_normalize_provider_trims_and_lowercases() -> None:
|
||||
assert _normalize_provider(" OpenAI ") == "openai"
|
||||
assert _normalize_provider("") is None
|
||||
assert _normalize_provider(123) is None
|
||||
|
||||
|
||||
def test_infer_provider_for_model_prefers_prefix_delimiter() -> None:
|
||||
assert _infer_provider_for_model("openai/gpt-5") == "openai"
|
||||
assert _infer_provider_for_model("anthropic:claude-sonnet") == "anthropic"
|
||||
assert _infer_provider_for_model("gpt-5") == "unknown"
|
||||
|
||||
|
||||
def test_model_settings_only_accepts_dict_payloads() -> None:
|
||||
settings = _model_settings({"provider": "openai", "temperature": 0.2})
|
||||
|
||||
assert settings == {"provider": "openai", "temperature": 0.2}
|
||||
assert _model_settings("not-a-dict") is None
|
||||
|
||||
|
||||
def test_parse_agent_model_value_normalizes_primary_and_fallbacks() -> None:
|
||||
primary, fallback = _parse_agent_model_value(
|
||||
{
|
||||
"primary": " openai/gpt-5 ",
|
||||
"fallbacks": [
|
||||
"openai/gpt-4.1",
|
||||
"openai/gpt-5",
|
||||
"openai/gpt-4.1",
|
||||
" ",
|
||||
123,
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert primary == "openai/gpt-5"
|
||||
assert fallback == ["openai/gpt-4.1"]
|
||||
|
||||
|
||||
def test_parse_agent_model_value_accepts_legacy_fallback_key() -> None:
|
||||
primary, fallback = _parse_agent_model_value(
|
||||
{
|
||||
"primary": "openai/gpt-5",
|
||||
"fallback": ["openai/gpt-4.1", "openai/gpt-4.1"],
|
||||
},
|
||||
)
|
||||
|
||||
assert primary == "openai/gpt-5"
|
||||
assert fallback == ["openai/gpt-4.1"]
|
||||
|
||||
|
||||
def test_parse_agent_model_value_accepts_string_primary() -> None:
|
||||
primary, fallback = _parse_agent_model_value(" openai/gpt-5 ")
|
||||
|
||||
assert primary == "openai/gpt-5"
|
||||
assert fallback == []
|
||||
|
||||
|
||||
def test_model_config_uses_fallbacks_key() -> None:
|
||||
assert _model_config("openai/gpt-5", ["openai/gpt-4.1"]) == {
|
||||
"primary": "openai/gpt-5",
|
||||
"fallbacks": ["openai/gpt-4.1"],
|
||||
}
|
||||
|
||||
|
||||
def test_extract_config_data_prefers_parsed_when_config_is_raw_string() -> None:
|
||||
config_data, base_hash = _extract_config_data(
|
||||
{
|
||||
"config": '{"agents":{"list":[{"id":"a1"}]}}',
|
||||
"parsed": {"agents": {"list": [{"id": "a1"}]}},
|
||||
"hash": "abc123",
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(config_data, dict)
|
||||
assert config_data.get("agents") == {"list": [{"id": "a1"}]}
|
||||
assert base_hash == "abc123"
|
||||
|
||||
|
||||
def test_extract_config_data_parses_json_string_when_parsed_absent() -> None:
|
||||
config_data, base_hash = _extract_config_data(
|
||||
{
|
||||
"config": '{"providers":{"openai":{"apiKey":"sk-test"}}}',
|
||||
"hash": "def456",
|
||||
},
|
||||
)
|
||||
|
||||
assert config_data.get("providers") == {"openai": {"apiKey": "sk-test"}}
|
||||
assert base_hash == "def456"
|
||||
Reference in New Issue
Block a user