feat(models): implement model routing and management pages with dynamic imports

This commit is contained in:
Abhimanyu Saharan
2026-02-12 00:16:18 +05:30
parent d5067e443b
commit dc7906a224
52 changed files with 6470 additions and 10 deletions

View File

@@ -0,0 +1,165 @@
"""API routes for gateway model registry and provider-auth management."""
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from app.api.deps import require_org_admin
from app.db.session import get_session
from app.schemas.common import OkResponse
from app.schemas.llm_models import (
GatewayModelPullResult,
GatewayModelSyncResult,
LlmModelCreate,
LlmModelRead,
LlmModelUpdate,
LlmProviderAuthCreate,
LlmProviderAuthRead,
LlmProviderAuthUpdate,
)
from app.services.openclaw.model_registry_service import GatewayModelRegistryService
from app.services.organizations import OrganizationContext
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/model-registry", tags=["model-registry"])
SESSION_DEP = Depends(get_session)
ORG_ADMIN_DEP = Depends(require_org_admin)
GATEWAY_ID_QUERY = Query(default=None)
@router.get("/provider-auth", response_model=list[LlmProviderAuthRead])
async def list_provider_auth(
gateway_id: UUID | None = GATEWAY_ID_QUERY,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> list[LlmProviderAuthRead]:
"""List provider auth records for the active organization."""
service = GatewayModelRegistryService(session)
return await service.list_provider_auth(ctx=ctx, gateway_id=gateway_id)
@router.post("/provider-auth", response_model=LlmProviderAuthRead)
async def create_provider_auth(
payload: LlmProviderAuthCreate,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> LlmProviderAuthRead:
"""Create a provider auth record and sync gateway config."""
service = GatewayModelRegistryService(session)
return await service.create_provider_auth(payload=payload, ctx=ctx)
@router.patch("/provider-auth/{provider_auth_id}", response_model=LlmProviderAuthRead)
async def update_provider_auth(
provider_auth_id: UUID,
payload: LlmProviderAuthUpdate,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> LlmProviderAuthRead:
"""Patch a provider auth record and sync gateway config."""
service = GatewayModelRegistryService(session)
return await service.update_provider_auth(
provider_auth_id=provider_auth_id,
payload=payload,
ctx=ctx,
)
@router.delete("/provider-auth/{provider_auth_id}", response_model=OkResponse)
async def delete_provider_auth(
provider_auth_id: UUID,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse:
"""Delete a provider auth record and sync gateway config."""
service = GatewayModelRegistryService(session)
await service.delete_provider_auth(provider_auth_id=provider_auth_id, ctx=ctx)
return OkResponse()
@router.get("/models", response_model=list[LlmModelRead])
async def list_models(
gateway_id: UUID | None = GATEWAY_ID_QUERY,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> list[LlmModelRead]:
"""List gateway model catalog entries for the active organization."""
service = GatewayModelRegistryService(session)
return await service.list_models(ctx=ctx, gateway_id=gateway_id)
@router.post("/models", response_model=LlmModelRead)
async def create_model(
payload: LlmModelCreate,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> LlmModelRead:
"""Create a model catalog entry and sync gateway config."""
service = GatewayModelRegistryService(session)
return await service.create_model(payload=payload, ctx=ctx)
@router.patch("/models/{model_id}", response_model=LlmModelRead)
async def update_model(
model_id: UUID,
payload: LlmModelUpdate,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> LlmModelRead:
"""Patch a model catalog entry and sync gateway config."""
service = GatewayModelRegistryService(session)
return await service.update_model(model_id=model_id, payload=payload, ctx=ctx)
@router.delete("/models/{model_id}", response_model=OkResponse)
async def delete_model(
model_id: UUID,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse:
"""Delete a model catalog entry and sync gateway config."""
service = GatewayModelRegistryService(session)
await service.delete_model(model_id=model_id, ctx=ctx)
return OkResponse()
@router.post("/gateways/{gateway_id}/sync", response_model=GatewayModelSyncResult)
async def sync_gateway_models(
gateway_id: UUID,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> GatewayModelSyncResult:
"""Push provider auth + model catalog + agent model links to a gateway."""
service = GatewayModelRegistryService(session)
gateway = await service.require_gateway(
gateway_id=gateway_id,
organization_id=ctx.organization.id,
)
return await service.sync_gateway_config(
gateway=gateway,
organization_id=ctx.organization.id,
)
@router.post("/gateways/{gateway_id}/pull", response_model=GatewayModelPullResult)
async def pull_gateway_models(
gateway_id: UUID,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> GatewayModelPullResult:
"""Pull provider auth + model catalog + agent model links from a gateway."""
service = GatewayModelRegistryService(session)
gateway = await service.require_gateway(
gateway_id=gateway_id,
organization_id=ctx.organization.id,
)
return await service.pull_gateway_config(
gateway=gateway,
organization_id=ctx.organization.id,
)

View File

@@ -22,6 +22,7 @@ from app.api.boards import router as boards_router
from app.api.gateway import router as gateway_router
from app.api.gateways import router as gateways_router
from app.api.metrics import router as metrics_router
from app.api.model_registry import router as model_registry_router
from app.api.organizations import router as organizations_router
from app.api.souls_directory import router as souls_directory_router
from app.api.tasks import router as tasks_router
@@ -98,6 +99,7 @@ api_v1.include_router(activity_router)
api_v1.include_router(gateway_router)
api_v1.include_router(gateways_router)
api_v1.include_router(metrics_router)
api_v1.include_router(model_registry_router)
api_v1.include_router(organizations_router)
api_v1.include_router(souls_directory_router)
api_v1.include_router(board_groups_router)

View File

@@ -10,6 +10,7 @@ from app.models.board_memory import BoardMemory
from app.models.board_onboarding import BoardOnboardingSession
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.llm import LlmModel, LlmProviderAuth
from app.models.organization_board_access import OrganizationBoardAccess
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
from app.models.organization_invites import OrganizationInvite
@@ -31,6 +32,8 @@ __all__ = [
"BoardGroup",
"Board",
"Gateway",
"LlmModel",
"LlmProviderAuth",
"Organization",
"OrganizationMember",
"OrganizationBoardAccess",

View File

@@ -44,5 +44,10 @@ class Agent(QueryModel, table=True):
delete_confirm_token_hash: str | None = Field(default=None, index=True)
last_seen_at: datetime | None = Field(default=None)
is_board_lead: bool = Field(default=False, index=True)
primary_model_id: UUID | None = Field(default=None, foreign_key="llm_models.id", index=True)
fallback_model_ids: list[str] | None = Field(
default=None,
sa_column=Column(JSON),
)
created_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)

46
backend/app/models/llm.py Normal file
View File

@@ -0,0 +1,46 @@
"""Models for gateway-scoped LLM provider auth and model catalog records."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field
from app.core.time import utcnow
from app.models.base import QueryModel
RUNTIME_ANNOTATION_TYPES = (datetime,)
class LlmProviderAuth(QueryModel, table=True):
"""Provider auth settings to write into a specific gateway config."""
__tablename__ = "llm_provider_auth" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
gateway_id: UUID = Field(foreign_key="gateways.id", index=True)
provider: str = Field(index=True)
config_path: str
secret: str
created_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)
class LlmModel(QueryModel, table=True):
"""Gateway model catalog entries available for agent assignment."""
__tablename__ = "llm_models" # pyright: ignore[reportAssignmentType]
id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_id: UUID = Field(foreign_key="organizations.id", index=True)
gateway_id: UUID = Field(foreign_key="gateways.id", index=True)
provider: str = Field(index=True)
model_id: str = Field(index=True)
display_name: str
settings: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)

View File

@@ -13,6 +13,16 @@ from app.schemas.board_onboarding import (
)
from app.schemas.boards import BoardCreate, BoardRead, BoardUpdate
from app.schemas.gateways import GatewayCreate, GatewayRead, GatewayUpdate
from app.schemas.llm_models import (
GatewayModelPullResult,
GatewayModelSyncResult,
LlmModelCreate,
LlmModelRead,
LlmModelUpdate,
LlmProviderAuthCreate,
LlmProviderAuthRead,
LlmProviderAuthUpdate,
)
from app.schemas.metrics import DashboardMetrics
from app.schemas.organizations import (
OrganizationActiveUpdate,
@@ -57,6 +67,14 @@ __all__ = [
"GatewayRead",
"GatewayUpdate",
"DashboardMetrics",
"GatewayModelPullResult",
"GatewayModelSyncResult",
"LlmModelCreate",
"LlmModelRead",
"LlmModelUpdate",
"LlmProviderAuthCreate",
"LlmProviderAuthRead",
"LlmProviderAuthUpdate",
"OrganizationActiveUpdate",
"OrganizationCreate",
"OrganizationInviteAccept",

View File

@@ -39,6 +39,27 @@ def _normalize_identity_profile(
return normalized or None
def _normalize_model_ids(
model_ids: object,
) -> list[UUID] | None:
if model_ids is None:
return None
if not isinstance(model_ids, (list, tuple, set)):
raise ValueError("fallback_model_ids must be a list")
normalized: list[UUID] = []
seen: set[UUID] = set()
for raw in model_ids:
candidate = str(raw).strip()
if not candidate:
continue
model_id = UUID(candidate)
if model_id in seen:
continue
seen.add(model_id)
normalized.append(model_id)
return normalized or None
class AgentBase(SQLModel):
"""Common fields shared by agent create/read/update payloads."""
@@ -46,6 +67,8 @@ class AgentBase(SQLModel):
name: NonEmptyStr
status: str = "provisioning"
heartbeat_config: dict[str, Any] | None = None
primary_model_id: UUID | None = None
fallback_model_ids: list[UUID] | None = None
identity_profile: dict[str, Any] | None = None
identity_template: str | None = None
soul_template: str | None = None
@@ -70,6 +93,15 @@ class AgentBase(SQLModel):
"""Normalize identity-profile values into trimmed string mappings."""
return _normalize_identity_profile(value)
@field_validator("fallback_model_ids", mode="before")
@classmethod
def normalize_fallback_model_ids(
cls,
value: object,
) -> list[UUID] | None:
"""Normalize fallback model ids into ordered UUID values."""
return _normalize_model_ids(value)
class AgentCreate(AgentBase):
"""Payload for creating a new agent."""
@@ -83,6 +115,8 @@ class AgentUpdate(SQLModel):
name: NonEmptyStr | None = None
status: str | None = None
heartbeat_config: dict[str, Any] | None = None
primary_model_id: UUID | None = None
fallback_model_ids: list[UUID] | None = None
identity_profile: dict[str, Any] | None = None
identity_template: str | None = None
soul_template: str | None = None
@@ -107,6 +141,15 @@ class AgentUpdate(SQLModel):
"""Normalize identity-profile values into trimmed string mappings."""
return _normalize_identity_profile(value)
@field_validator("fallback_model_ids", mode="before")
@classmethod
def normalize_fallback_model_ids(
cls,
value: object,
) -> list[UUID] | None:
"""Normalize fallback model ids into ordered UUID values."""
return _normalize_model_ids(value)
class AgentRead(AgentBase):
"""Public agent representation returned by the API."""

View File

@@ -0,0 +1,167 @@
"""Schemas for LLM provider auth, model catalog, and gateway sync payloads."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import field_validator
from sqlmodel import Field, SQLModel
from app.schemas.common import NonEmptyStr
RUNTIME_ANNOTATION_TYPES = (datetime, UUID, NonEmptyStr)
def _normalize_provider(value: object) -> str | object:
if isinstance(value, str):
normalized = value.strip().lower()
return normalized or value
return value
def _normalize_config_path(value: object) -> str | object:
if isinstance(value, str):
normalized = value.strip()
return normalized or value
return value
def _default_provider_config_path(provider: str) -> str:
return f"providers.{provider}.apiKey"
class LlmProviderAuthBase(SQLModel):
"""Shared provider auth fields."""
gateway_id: UUID
provider: NonEmptyStr
config_path: NonEmptyStr | None = None
@field_validator("provider", mode="before")
@classmethod
def normalize_provider(cls, value: object) -> str | object:
return _normalize_provider(value)
@field_validator("config_path", mode="before")
@classmethod
def normalize_config_path(cls, value: object) -> str | object:
return _normalize_config_path(value)
class LlmProviderAuthCreate(LlmProviderAuthBase):
"""Payload used to create a provider auth record."""
secret: NonEmptyStr
class LlmProviderAuthUpdate(SQLModel):
"""Payload used to patch an existing provider auth record."""
provider: NonEmptyStr | None = None
config_path: NonEmptyStr | None = None
secret: NonEmptyStr | None = None
@field_validator("provider", mode="before")
@classmethod
def normalize_provider(cls, value: object) -> str | object:
return _normalize_provider(value)
@field_validator("config_path", mode="before")
@classmethod
def normalize_config_path(cls, value: object) -> str | object:
return _normalize_config_path(value)
class LlmProviderAuthRead(SQLModel):
"""Public provider auth payload (secret value is never returned)."""
id: UUID
organization_id: UUID
gateway_id: UUID
provider: str
config_path: str
has_secret: bool = True
created_at: datetime
updated_at: datetime
class LlmModelBase(SQLModel):
"""Shared gateway model catalog fields."""
gateway_id: UUID
provider: NonEmptyStr
model_id: NonEmptyStr
display_name: NonEmptyStr
settings: dict[str, Any] | None = None
@field_validator("provider", mode="before")
@classmethod
def normalize_provider(cls, value: object) -> str | object:
return _normalize_provider(value)
class LlmModelCreate(LlmModelBase):
"""Payload used to create a model catalog entry."""
class LlmModelUpdate(SQLModel):
"""Payload used to patch an existing model catalog entry."""
provider: NonEmptyStr | None = None
model_id: NonEmptyStr | None = None
display_name: NonEmptyStr | None = None
settings: dict[str, Any] | None = None
@field_validator("provider", mode="before")
@classmethod
def normalize_provider(cls, value: object) -> str | object:
return _normalize_provider(value)
class LlmModelRead(SQLModel):
"""Public model catalog entry payload."""
id: UUID
organization_id: UUID
gateway_id: UUID
provider: str
model_id: str
display_name: str
settings: dict[str, Any] | None = None
created_at: datetime
updated_at: datetime
class GatewayModelSyncResult(SQLModel):
"""Summary of model/provider config sync operations for a gateway."""
gateway_id: UUID
provider_auth_patched: int
model_catalog_patched: int
agent_models_patched: int
sessions_patched: int
errors: list[str] = Field(default_factory=list)
class GatewayModelPullResult(SQLModel):
"""Summary of model/provider config pull operations for a gateway."""
gateway_id: UUID
provider_auth_imported: int
model_catalog_imported: int
agent_models_imported: int
errors: list[str] = Field(default_factory=list)
__all__ = [
"GatewayModelPullResult",
"GatewayModelSyncResult",
"LlmModelCreate",
"LlmModelRead",
"LlmModelUpdate",
"LlmProviderAuthCreate",
"LlmProviderAuthRead",
"LlmProviderAuthUpdate",
]

File diff suppressed because it is too large Load Diff

View File

@@ -492,10 +492,32 @@ async def _gateway_config_agent_list(
msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict):
msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
# Prefer parsed object over raw serialized config to support older gateways.
raw_parsed = cfg.get("parsed")
raw_config = cfg.get("config")
data: dict[str, Any]
def _parse_json_config(raw: str) -> dict[str, Any]:
try:
parsed = json.loads(raw)
except json.JSONDecodeError as exc:
msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg) from exc
if not isinstance(parsed, dict):
msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
return parsed
if isinstance(raw_parsed, dict):
data = raw_parsed
elif isinstance(raw_config, dict):
data = raw_config
elif isinstance(raw_parsed, str) and raw_parsed.strip():
data = _parse_json_config(raw_parsed)
elif isinstance(raw_config, str) and raw_config.strip():
data = _parse_json_config(raw_config)
else:
data = {}
agents_section = data.get("agents") or {}
agents_list = agents_section.get("list") or []

View File

@@ -32,6 +32,7 @@ from app.models.agents import Agent
from app.models.board_memory import BoardMemory
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.llm import LlmModel
from app.models.organizations import Organization
from app.models.tasks import Task
from app.schemas.agents import (
@@ -72,6 +73,7 @@ from app.services.openclaw.internal.session_keys import (
board_agent_session_key,
board_lead_session_key,
)
from app.services.openclaw.model_registry_service import GatewayModelRegistryService
from app.services.openclaw.policies import OpenClawAuthorizationPolicy
from app.services.openclaw.provisioning import (
OpenClawGatewayControlPlane,
@@ -959,6 +961,71 @@ class AgentLifecycleService(OpenClawDBService):
detail="An agent with this name already exists in this gateway workspace.",
)
@staticmethod
def _normalized_fallback_ids(value: object) -> list[UUID]:
if not isinstance(value, list):
return []
normalized: list[UUID] = []
seen: set[UUID] = set()
for raw in value:
candidate = str(raw).strip()
if not candidate:
continue
try:
model_id = UUID(candidate)
except ValueError:
continue
if model_id in seen:
continue
seen.add(model_id)
normalized.append(model_id)
return normalized
async def normalize_agent_model_assignments(
self,
*,
gateway_id: UUID,
primary_model_id: UUID | None,
fallback_model_ids: list[UUID],
) -> tuple[UUID | None, list[str] | None]:
if primary_model_id is None and not fallback_model_ids:
return None, None
candidate_ids: list[UUID] = []
if primary_model_id is not None:
candidate_ids.append(primary_model_id)
candidate_ids.extend(fallback_model_ids)
unique_ids = list(dict.fromkeys(candidate_ids))
statement = (
select(LlmModel.id)
.where(col(LlmModel.gateway_id) == gateway_id)
.where(col(LlmModel.id).in_(unique_ids))
)
with self.session.no_autoflush:
valid_ids = set(await self.session.exec(statement))
missing = [value for value in unique_ids if value not in valid_ids]
if missing:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Model assignment includes model ids not in the agent gateway catalog.",
)
filtered_fallback: list[str] = []
for fallback_id in fallback_model_ids:
if primary_model_id is not None and fallback_id == primary_model_id:
continue
value = str(fallback_id)
if value in filtered_fallback:
continue
filtered_fallback.append(value)
return primary_model_id, (filtered_fallback or None)
async def sync_gateway_agent_models(self, *, gateway: Gateway) -> None:
await GatewayModelRegistryService(self.session).sync_gateway_config(
gateway=gateway,
organization_id=gateway.organization_id,
)
async def persist_new_agent(
self,
*,
@@ -1177,6 +1244,13 @@ class AgentLifecycleService(OpenClawDBService):
detail="Board gateway_id is required",
)
updates["gateway_id"] = board.gateway_id
if "primary_model_id" in updates:
primary_value = updates["primary_model_id"]
if primary_value is not None and not isinstance(primary_value, UUID):
updates["primary_model_id"] = UUID(str(primary_value))
if "fallback_model_ids" in updates:
normalized_fallback = self._normalized_fallback_ids(updates["fallback_model_ids"])
updates["fallback_model_ids"] = [str(model_id) for model_id in normalized_fallback] or None
for key, value in updates.items():
setattr(agent, key, value)
@@ -1192,6 +1266,19 @@ class AgentLifecycleService(OpenClawDBService):
detail="Board gateway_id is required",
)
agent.gateway_id = board.gateway_id
primary_model_id = agent.primary_model_id
if primary_model_id is not None and not isinstance(primary_model_id, UUID):
primary_model_id = UUID(str(primary_model_id))
fallback_model_ids = self._normalized_fallback_ids(agent.fallback_model_ids)
normalized_primary, normalized_fallback = await self.normalize_agent_model_assignments(
gateway_id=agent.gateway_id,
primary_model_id=primary_model_id if isinstance(primary_model_id, UUID) else None,
fallback_model_ids=fallback_model_ids,
)
agent.primary_model_id = normalized_primary
agent.fallback_model_ids = normalized_fallback
agent.updated_at = utcnow()
if agent.heartbeat_config is None:
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
@@ -1487,6 +1574,17 @@ class AgentLifecycleService(OpenClawDBService):
gateway, _client_config = await self.require_gateway(board)
data = payload.model_dump()
data["gateway_id"] = gateway.id
primary_model_id = data.get("primary_model_id")
if primary_model_id is not None and not isinstance(primary_model_id, UUID):
primary_model_id = UUID(str(primary_model_id))
fallback_model_ids = self._normalized_fallback_ids(data.get("fallback_model_ids"))
normalized_primary, normalized_fallback = await self.normalize_agent_model_assignments(
gateway_id=gateway.id,
primary_model_id=primary_model_id if isinstance(primary_model_id, UUID) else None,
fallback_model_ids=fallback_model_ids,
)
data["primary_model_id"] = normalized_primary
data["fallback_model_ids"] = normalized_fallback
requested_name = (data.get("name") or "").strip()
await self.ensure_unique_agent_name(
board=board,
@@ -1502,6 +1600,8 @@ class AgentLifecycleService(OpenClawDBService):
user=actor.user if actor.actor_type == "user" else None,
force_bootstrap=False,
)
if agent.primary_model_id is not None or agent.fallback_model_ids:
await self.sync_gateway_agent_models(gateway=gateway)
self.logger.info("agent.create.success agent_id=%s board_id=%s", agent.id, board.id)
return self.to_agent_read(self.with_computed_status(agent))
@@ -1535,6 +1635,7 @@ class AgentLifecycleService(OpenClawDBService):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await self.require_agent_access(agent=agent, ctx=options.context, write=True)
updates = payload.model_dump(exclude_unset=True)
sync_model_assignments = "primary_model_id" in updates or "fallback_model_ids" in updates
make_main = updates.pop("is_gateway_main", None)
await self.validate_agent_update_inputs(
ctx=options.context,
@@ -1568,6 +1669,8 @@ class AgentLifecycleService(OpenClawDBService):
agent=agent,
request=provision_request,
)
if sync_model_assignments or agent.primary_model_id is not None or agent.fallback_model_ids:
await self.sync_gateway_agent_models(gateway=target.gateway)
self.logger.info("agent.update.success agent_id=%s", agent.id)
return self.to_agent_read(self.with_computed_status(agent))

View File

@@ -0,0 +1,285 @@
"""add llm model registry
Revision ID: 9a3fb1158c2d
Revises: f4d2b649e93a
Create Date: 2026-02-11 21:00:00.000000
"""
from __future__ import annotations
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "9a3fb1158c2d"
down_revision = "f4d2b649e93a"
branch_labels = None
depends_on = None
def _has_index(inspector: sa.Inspector, table: str, index_name: str) -> bool:
return any(item.get("name") == index_name for item in inspector.get_indexes(table))
def _has_unique(
inspector: sa.Inspector,
table: str,
*,
name: str | None = None,
columns: tuple[str, ...] | None = None,
) -> bool:
unique_constraints = inspector.get_unique_constraints(table)
for item in unique_constraints:
if name and item.get("name") == name:
return True
if columns and tuple(item.get("column_names") or ()) == columns:
return True
return False
def _column_names(inspector: sa.Inspector, table: str) -> set[str]:
return {item["name"] for item in inspector.get_columns(table)}
def _has_foreign_key(
inspector: sa.Inspector,
table: str,
*,
constrained_columns: tuple[str, ...],
referred_table: str,
) -> bool:
for item in inspector.get_foreign_keys(table):
if tuple(item.get("constrained_columns") or ()) != constrained_columns:
continue
if item.get("referred_table") != referred_table:
continue
return True
return False
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
if not inspector.has_table("llm_provider_auth"):
op.create_table(
"llm_provider_auth",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("organization_id", sa.Uuid(), nullable=False),
sa.Column("gateway_id", sa.Uuid(), nullable=False),
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("config_path", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["gateway_id"], ["gateways.id"]),
sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"gateway_id",
"provider",
"config_path",
name="uq_llm_provider_auth_gateway_provider_path",
),
)
inspector = sa.inspect(bind)
else:
existing_columns = _column_names(inspector, "llm_provider_auth")
if "config_path" not in existing_columns:
op.add_column(
"llm_provider_auth",
sa.Column(
"config_path",
sqlmodel.sql.sqltypes.AutoString(),
nullable=False,
server_default="providers.openai.apiKey",
),
)
op.alter_column("llm_provider_auth", "config_path", server_default=None)
inspector = sa.inspect(bind)
if not _has_unique(
inspector,
"llm_provider_auth",
name="uq_llm_provider_auth_gateway_provider_path",
columns=("gateway_id", "provider", "config_path"),
):
op.create_unique_constraint(
"uq_llm_provider_auth_gateway_provider_path",
"llm_provider_auth",
["gateway_id", "provider", "config_path"],
)
inspector = sa.inspect(bind)
if not _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_gateway_id")):
op.create_index(
op.f("ix_llm_provider_auth_gateway_id"),
"llm_provider_auth",
["gateway_id"],
unique=False,
)
inspector = sa.inspect(bind)
if not _has_index(
inspector,
"llm_provider_auth",
op.f("ix_llm_provider_auth_organization_id"),
):
op.create_index(
op.f("ix_llm_provider_auth_organization_id"),
"llm_provider_auth",
["organization_id"],
unique=False,
)
inspector = sa.inspect(bind)
if not _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_provider")):
op.create_index(
op.f("ix_llm_provider_auth_provider"),
"llm_provider_auth",
["provider"],
unique=False,
)
inspector = sa.inspect(bind)
if not inspector.has_table("llm_models"):
op.create_table(
"llm_models",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("organization_id", sa.Uuid(), nullable=False),
sa.Column("gateway_id", sa.Uuid(), nullable=False),
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("model_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("settings", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["gateway_id"], ["gateways.id"]),
sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("gateway_id", "model_id", name="uq_llm_models_gateway_model_id"),
)
inspector = sa.inspect(bind)
elif not _has_unique(
inspector,
"llm_models",
name="uq_llm_models_gateway_model_id",
columns=("gateway_id", "model_id"),
):
op.create_unique_constraint(
"uq_llm_models_gateway_model_id",
"llm_models",
["gateway_id", "model_id"],
)
inspector = sa.inspect(bind)
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_gateway_id")):
op.create_index(
op.f("ix_llm_models_gateway_id"),
"llm_models",
["gateway_id"],
unique=False,
)
inspector = sa.inspect(bind)
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_model_id")):
op.create_index(
op.f("ix_llm_models_model_id"),
"llm_models",
["model_id"],
unique=False,
)
inspector = sa.inspect(bind)
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_organization_id")):
op.create_index(
op.f("ix_llm_models_organization_id"),
"llm_models",
["organization_id"],
unique=False,
)
inspector = sa.inspect(bind)
if not _has_index(inspector, "llm_models", op.f("ix_llm_models_provider")):
op.create_index(
op.f("ix_llm_models_provider"),
"llm_models",
["provider"],
unique=False,
)
inspector = sa.inspect(bind)
agent_columns = _column_names(inspector, "agents")
if "primary_model_id" not in agent_columns:
op.add_column("agents", sa.Column("primary_model_id", sa.Uuid(), nullable=True))
inspector = sa.inspect(bind)
if "fallback_model_ids" not in agent_columns:
op.add_column("agents", sa.Column("fallback_model_ids", sa.JSON(), nullable=True))
inspector = sa.inspect(bind)
if not _has_index(inspector, "agents", op.f("ix_agents_primary_model_id")):
op.create_index(op.f("ix_agents_primary_model_id"), "agents", ["primary_model_id"], unique=False)
inspector = sa.inspect(bind)
if not _has_foreign_key(
inspector,
"agents",
constrained_columns=("primary_model_id",),
referred_table="llm_models",
):
op.create_foreign_key(
"fk_agents_primary_model_id_llm_models",
"agents",
"llm_models",
["primary_model_id"],
["id"],
)
def downgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
if inspector.has_table("agents"):
for fk in inspector.get_foreign_keys("agents"):
if tuple(fk.get("constrained_columns") or ()) != ("primary_model_id",):
continue
if fk.get("referred_table") != "llm_models":
continue
fk_name = fk.get("name")
if fk_name:
op.drop_constraint(fk_name, "agents", type_="foreignkey")
inspector = sa.inspect(bind)
if _has_index(inspector, "agents", op.f("ix_agents_primary_model_id")):
op.drop_index(op.f("ix_agents_primary_model_id"), table_name="agents")
agent_columns = _column_names(inspector, "agents")
if "fallback_model_ids" in agent_columns:
op.drop_column("agents", "fallback_model_ids")
if "primary_model_id" in agent_columns:
op.drop_column("agents", "primary_model_id")
inspector = sa.inspect(bind)
if inspector.has_table("llm_models"):
if _has_index(inspector, "llm_models", op.f("ix_llm_models_provider")):
op.drop_index(op.f("ix_llm_models_provider"), table_name="llm_models")
if _has_index(inspector, "llm_models", op.f("ix_llm_models_organization_id")):
op.drop_index(op.f("ix_llm_models_organization_id"), table_name="llm_models")
if _has_index(inspector, "llm_models", op.f("ix_llm_models_model_id")):
op.drop_index(op.f("ix_llm_models_model_id"), table_name="llm_models")
if _has_index(inspector, "llm_models", op.f("ix_llm_models_gateway_id")):
op.drop_index(op.f("ix_llm_models_gateway_id"), table_name="llm_models")
op.drop_table("llm_models")
inspector = sa.inspect(bind)
if inspector.has_table("llm_provider_auth"):
if _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_provider")):
op.drop_index(op.f("ix_llm_provider_auth_provider"), table_name="llm_provider_auth")
if _has_index(
inspector,
"llm_provider_auth",
op.f("ix_llm_provider_auth_organization_id"),
):
op.drop_index(
op.f("ix_llm_provider_auth_organization_id"),
table_name="llm_provider_auth",
)
if _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_gateway_id")):
op.drop_index(
op.f("ix_llm_provider_auth_gateway_id"),
table_name="llm_provider_auth",
)
op.drop_table("llm_provider_auth")

View File

@@ -0,0 +1,105 @@
# ruff: noqa: S101
"""Regression tests for agent model-assignment update normalization."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
import pytest
from unittest.mock import AsyncMock
from app.services.openclaw.provisioning_db import AgentLifecycleService
class _NoAutoflush:
def __init__(self, session: "_SessionStub") -> None:
self._session = session
def __enter__(self) -> None:
self._session.in_no_autoflush = True
return None
def __exit__(self, exc_type, exc, tb) -> bool:
self._session.in_no_autoflush = False
return False
class _SessionStub:
def __init__(self, valid_ids: set[UUID]) -> None:
self.valid_ids = valid_ids
self.in_no_autoflush = False
self.commits = 0
@property
def no_autoflush(self) -> _NoAutoflush:
return _NoAutoflush(self)
async def exec(self, _statement: Any) -> list[UUID]:
if not self.in_no_autoflush:
raise AssertionError("Expected normalize query to run under no_autoflush.")
return list(self.valid_ids)
def add(self, _model: Any) -> None:
return None
async def commit(self) -> None:
self.commits += 1
async def refresh(self, _model: Any) -> None:
return None
@dataclass
class _AgentStub:
gateway_id: UUID
board_id: UUID | None = None
is_board_lead: bool = False
openclaw_session_id: str | None = None
primary_model_id: UUID | None = None
fallback_model_ids: list[str] | None = None
updated_at: datetime | None = None
heartbeat_config: dict[str, Any] | None = field(
default_factory=lambda: {"interval_seconds": 5},
)
@pytest.mark.asyncio
async def test_normalize_agent_model_assignments_uses_no_autoflush() -> None:
primary = uuid4()
fallback = uuid4()
session = _SessionStub({primary, fallback})
service = AgentLifecycleService(session) # type: ignore[arg-type]
normalized_primary, normalized_fallback = await service.normalize_agent_model_assignments(
gateway_id=uuid4(),
primary_model_id=primary,
fallback_model_ids=[primary, fallback],
)
assert normalized_primary == primary
assert normalized_fallback == [str(fallback)]
@pytest.mark.asyncio
async def test_apply_agent_update_mutations_coerces_fallback_ids_to_strings(monkeypatch) -> None:
primary = uuid4()
fallback = uuid4()
session = _SessionStub({primary, fallback})
service = AgentLifecycleService(session) # type: ignore[arg-type]
monkeypatch.setattr(service, "get_main_agent_gateway", AsyncMock(return_value=None))
agent = _AgentStub(gateway_id=uuid4())
updates: dict[str, Any] = {
"primary_model_id": primary,
"fallback_model_ids": [primary, fallback, fallback],
}
await service.apply_agent_update_mutations(agent=agent, updates=updates, make_main=None) # type: ignore[arg-type]
assert updates["fallback_model_ids"] == [str(primary), str(fallback)]
assert agent.primary_model_id == primary
assert agent.fallback_model_ids == [str(fallback)]
assert session.commits == 1

View File

@@ -0,0 +1,28 @@
# ruff: noqa: S101
"""Tests for agent model-assignment schema normalization."""
from __future__ import annotations
from uuid import uuid4
import pytest
from pydantic import ValidationError
from app.schemas.agents import AgentCreate
def test_agent_create_normalizes_fallback_model_ids() -> None:
model_a = uuid4()
model_b = uuid4()
payload = AgentCreate(
name="Worker",
fallback_model_ids=[str(model_a), str(model_b), str(model_a)],
)
assert payload.fallback_model_ids == [model_a, model_b]
def test_agent_create_rejects_non_list_fallback_model_ids() -> None:
with pytest.raises(ValidationError):
AgentCreate(name="Worker", fallback_model_ids="not-a-list")

View File

@@ -0,0 +1,110 @@
# ruff: noqa: S101
"""Tests for gateway model-registry pull helper normalization."""
from __future__ import annotations
from app.services.openclaw.model_registry_service import (
_extract_config_data,
_get_nested_path,
_infer_provider_for_model,
_model_config,
_model_settings,
_normalize_provider,
_parse_agent_model_value,
)
def test_get_nested_path_resolves_existing_value() -> None:
source = {"providers": {"openai": {"apiKey": "sk-test"}}}
assert _get_nested_path(source, ["providers", "openai", "apiKey"]) == "sk-test"
assert _get_nested_path(source, ["providers", "anthropic", "apiKey"]) is None
def test_normalize_provider_trims_and_lowercases() -> None:
assert _normalize_provider(" OpenAI ") == "openai"
assert _normalize_provider("") is None
assert _normalize_provider(123) is None
def test_infer_provider_for_model_prefers_prefix_delimiter() -> None:
assert _infer_provider_for_model("openai/gpt-5") == "openai"
assert _infer_provider_for_model("anthropic:claude-sonnet") == "anthropic"
assert _infer_provider_for_model("gpt-5") == "unknown"
def test_model_settings_only_accepts_dict_payloads() -> None:
settings = _model_settings({"provider": "openai", "temperature": 0.2})
assert settings == {"provider": "openai", "temperature": 0.2}
assert _model_settings("not-a-dict") is None
def test_parse_agent_model_value_normalizes_primary_and_fallbacks() -> None:
primary, fallback = _parse_agent_model_value(
{
"primary": " openai/gpt-5 ",
"fallbacks": [
"openai/gpt-4.1",
"openai/gpt-5",
"openai/gpt-4.1",
" ",
123,
],
},
)
assert primary == "openai/gpt-5"
assert fallback == ["openai/gpt-4.1"]
def test_parse_agent_model_value_accepts_legacy_fallback_key() -> None:
primary, fallback = _parse_agent_model_value(
{
"primary": "openai/gpt-5",
"fallback": ["openai/gpt-4.1", "openai/gpt-4.1"],
},
)
assert primary == "openai/gpt-5"
assert fallback == ["openai/gpt-4.1"]
def test_parse_agent_model_value_accepts_string_primary() -> None:
primary, fallback = _parse_agent_model_value(" openai/gpt-5 ")
assert primary == "openai/gpt-5"
assert fallback == []
def test_model_config_uses_fallbacks_key() -> None:
assert _model_config("openai/gpt-5", ["openai/gpt-4.1"]) == {
"primary": "openai/gpt-5",
"fallbacks": ["openai/gpt-4.1"],
}
def test_extract_config_data_prefers_parsed_when_config_is_raw_string() -> None:
config_data, base_hash = _extract_config_data(
{
"config": '{"agents":{"list":[{"id":"a1"}]}}',
"parsed": {"agents": {"list": [{"id": "a1"}]}},
"hash": "abc123",
},
)
assert isinstance(config_data, dict)
assert config_data.get("agents") == {"list": [{"id": "a1"}]}
assert base_hash == "abc123"
def test_extract_config_data_parses_json_string_when_parsed_absent() -> None:
config_data, base_hash = _extract_config_data(
{
"config": '{"providers":{"openai":{"apiKey":"sk-test"}}}',
"hash": "def456",
},
)
assert config_data.get("providers") == {"openai": {"apiKey": "sk-test"}}
assert base_hash == "def456"