diff --git a/backend/app/api/model_registry.py b/backend/app/api/model_registry.py new file mode 100644 index 00000000..ecd1274f --- /dev/null +++ b/backend/app/api/model_registry.py @@ -0,0 +1,165 @@ +"""API routes for gateway model registry and provider-auth management.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from uuid import UUID + +from fastapi import APIRouter, Depends, Query + +from app.api.deps import require_org_admin +from app.db.session import get_session +from app.schemas.common import OkResponse +from app.schemas.llm_models import ( + GatewayModelPullResult, + GatewayModelSyncResult, + LlmModelCreate, + LlmModelRead, + LlmModelUpdate, + LlmProviderAuthCreate, + LlmProviderAuthRead, + LlmProviderAuthUpdate, +) +from app.services.openclaw.model_registry_service import GatewayModelRegistryService +from app.services.organizations import OrganizationContext + +if TYPE_CHECKING: + from sqlmodel.ext.asyncio.session import AsyncSession + +router = APIRouter(prefix="/model-registry", tags=["model-registry"]) + +SESSION_DEP = Depends(get_session) +ORG_ADMIN_DEP = Depends(require_org_admin) +GATEWAY_ID_QUERY = Query(default=None) + + +@router.get("/provider-auth", response_model=list[LlmProviderAuthRead]) +async def list_provider_auth( + gateway_id: UUID | None = GATEWAY_ID_QUERY, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> list[LlmProviderAuthRead]: + """List provider auth records for the active organization.""" + service = GatewayModelRegistryService(session) + return await service.list_provider_auth(ctx=ctx, gateway_id=gateway_id) + + +@router.post("/provider-auth", response_model=LlmProviderAuthRead) +async def create_provider_auth( + payload: LlmProviderAuthCreate, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> LlmProviderAuthRead: + """Create a provider auth record and sync gateway config.""" + service = GatewayModelRegistryService(session) + return await service.create_provider_auth(payload=payload, ctx=ctx) + + +@router.patch("/provider-auth/{provider_auth_id}", response_model=LlmProviderAuthRead) +async def update_provider_auth( + provider_auth_id: UUID, + payload: LlmProviderAuthUpdate, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> LlmProviderAuthRead: + """Patch a provider auth record and sync gateway config.""" + service = GatewayModelRegistryService(session) + return await service.update_provider_auth( + provider_auth_id=provider_auth_id, + payload=payload, + ctx=ctx, + ) + + +@router.delete("/provider-auth/{provider_auth_id}", response_model=OkResponse) +async def delete_provider_auth( + provider_auth_id: UUID, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> OkResponse: + """Delete a provider auth record and sync gateway config.""" + service = GatewayModelRegistryService(session) + await service.delete_provider_auth(provider_auth_id=provider_auth_id, ctx=ctx) + return OkResponse() + + +@router.get("/models", response_model=list[LlmModelRead]) +async def list_models( + gateway_id: UUID | None = GATEWAY_ID_QUERY, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> list[LlmModelRead]: + """List gateway model catalog entries for the active organization.""" + service = GatewayModelRegistryService(session) + return await service.list_models(ctx=ctx, gateway_id=gateway_id) + + +@router.post("/models", response_model=LlmModelRead) +async def create_model( + payload: LlmModelCreate, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> LlmModelRead: + """Create a model catalog entry and sync gateway config.""" + service = GatewayModelRegistryService(session) + return await service.create_model(payload=payload, ctx=ctx) + + +@router.patch("/models/{model_id}", response_model=LlmModelRead) +async def update_model( + model_id: UUID, + payload: LlmModelUpdate, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> LlmModelRead: + """Patch a model catalog entry and sync gateway config.""" + service = GatewayModelRegistryService(session) + return await service.update_model(model_id=model_id, payload=payload, ctx=ctx) + + +@router.delete("/models/{model_id}", response_model=OkResponse) +async def delete_model( + model_id: UUID, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> OkResponse: + """Delete a model catalog entry and sync gateway config.""" + service = GatewayModelRegistryService(session) + await service.delete_model(model_id=model_id, ctx=ctx) + return OkResponse() + + +@router.post("/gateways/{gateway_id}/sync", response_model=GatewayModelSyncResult) +async def sync_gateway_models( + gateway_id: UUID, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> GatewayModelSyncResult: + """Push provider auth + model catalog + agent model links to a gateway.""" + service = GatewayModelRegistryService(session) + gateway = await service.require_gateway( + gateway_id=gateway_id, + organization_id=ctx.organization.id, + ) + return await service.sync_gateway_config( + gateway=gateway, + organization_id=ctx.organization.id, + ) + + +@router.post("/gateways/{gateway_id}/pull", response_model=GatewayModelPullResult) +async def pull_gateway_models( + gateway_id: UUID, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, +) -> GatewayModelPullResult: + """Pull provider auth + model catalog + agent model links from a gateway.""" + service = GatewayModelRegistryService(session) + gateway = await service.require_gateway( + gateway_id=gateway_id, + organization_id=ctx.organization.id, + ) + return await service.pull_gateway_config( + gateway=gateway, + organization_id=ctx.organization.id, + ) diff --git a/backend/app/main.py b/backend/app/main.py index ebe0bba6..d1963e90 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -22,6 +22,7 @@ from app.api.boards import router as boards_router from app.api.gateway import router as gateway_router from app.api.gateways import router as gateways_router from app.api.metrics import router as metrics_router +from app.api.model_registry import router as model_registry_router from app.api.organizations import router as organizations_router from app.api.souls_directory import router as souls_directory_router from app.api.tasks import router as tasks_router @@ -98,6 +99,7 @@ api_v1.include_router(activity_router) api_v1.include_router(gateway_router) api_v1.include_router(gateways_router) api_v1.include_router(metrics_router) +api_v1.include_router(model_registry_router) api_v1.include_router(organizations_router) api_v1.include_router(souls_directory_router) api_v1.include_router(board_groups_router) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index c697c186..4db22647 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -10,6 +10,7 @@ from app.models.board_memory import BoardMemory from app.models.board_onboarding import BoardOnboardingSession from app.models.boards import Board from app.models.gateways import Gateway +from app.models.llm import LlmModel, LlmProviderAuth from app.models.organization_board_access import OrganizationBoardAccess from app.models.organization_invite_board_access import OrganizationInviteBoardAccess from app.models.organization_invites import OrganizationInvite @@ -31,6 +32,8 @@ __all__ = [ "BoardGroup", "Board", "Gateway", + "LlmModel", + "LlmProviderAuth", "Organization", "OrganizationMember", "OrganizationBoardAccess", diff --git a/backend/app/models/agents.py b/backend/app/models/agents.py index 1648e98f..c9c9142e 100644 --- a/backend/app/models/agents.py +++ b/backend/app/models/agents.py @@ -44,5 +44,10 @@ class Agent(QueryModel, table=True): delete_confirm_token_hash: str | None = Field(default=None, index=True) last_seen_at: datetime | None = Field(default=None) is_board_lead: bool = Field(default=False, index=True) + primary_model_id: UUID | None = Field(default=None, foreign_key="llm_models.id", index=True) + fallback_model_ids: list[str] | None = Field( + default=None, + sa_column=Column(JSON), + ) created_at: datetime = Field(default_factory=utcnow) updated_at: datetime = Field(default_factory=utcnow) diff --git a/backend/app/models/llm.py b/backend/app/models/llm.py new file mode 100644 index 00000000..03d1f3df --- /dev/null +++ b/backend/app/models/llm.py @@ -0,0 +1,46 @@ +"""Models for gateway-scoped LLM provider auth and model catalog records.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy import JSON, Column +from sqlmodel import Field + +from app.core.time import utcnow +from app.models.base import QueryModel + +RUNTIME_ANNOTATION_TYPES = (datetime,) + + +class LlmProviderAuth(QueryModel, table=True): + """Provider auth settings to write into a specific gateway config.""" + + __tablename__ = "llm_provider_auth" # pyright: ignore[reportAssignmentType] + + id: UUID = Field(default_factory=uuid4, primary_key=True) + organization_id: UUID = Field(foreign_key="organizations.id", index=True) + gateway_id: UUID = Field(foreign_key="gateways.id", index=True) + provider: str = Field(index=True) + config_path: str + secret: str + created_at: datetime = Field(default_factory=utcnow) + updated_at: datetime = Field(default_factory=utcnow) + + +class LlmModel(QueryModel, table=True): + """Gateway model catalog entries available for agent assignment.""" + + __tablename__ = "llm_models" # pyright: ignore[reportAssignmentType] + + id: UUID = Field(default_factory=uuid4, primary_key=True) + organization_id: UUID = Field(foreign_key="organizations.id", index=True) + gateway_id: UUID = Field(foreign_key="gateways.id", index=True) + provider: str = Field(index=True) + model_id: str = Field(index=True) + display_name: str + settings: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) + created_at: datetime = Field(default_factory=utcnow) + updated_at: datetime = Field(default_factory=utcnow) diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py index e4fa060d..a91d80a8 100644 --- a/backend/app/schemas/__init__.py +++ b/backend/app/schemas/__init__.py @@ -13,6 +13,16 @@ from app.schemas.board_onboarding import ( ) from app.schemas.boards import BoardCreate, BoardRead, BoardUpdate from app.schemas.gateways import GatewayCreate, GatewayRead, GatewayUpdate +from app.schemas.llm_models import ( + GatewayModelPullResult, + GatewayModelSyncResult, + LlmModelCreate, + LlmModelRead, + LlmModelUpdate, + LlmProviderAuthCreate, + LlmProviderAuthRead, + LlmProviderAuthUpdate, +) from app.schemas.metrics import DashboardMetrics from app.schemas.organizations import ( OrganizationActiveUpdate, @@ -57,6 +67,14 @@ __all__ = [ "GatewayRead", "GatewayUpdate", "DashboardMetrics", + "GatewayModelPullResult", + "GatewayModelSyncResult", + "LlmModelCreate", + "LlmModelRead", + "LlmModelUpdate", + "LlmProviderAuthCreate", + "LlmProviderAuthRead", + "LlmProviderAuthUpdate", "OrganizationActiveUpdate", "OrganizationCreate", "OrganizationInviteAccept", diff --git a/backend/app/schemas/agents.py b/backend/app/schemas/agents.py index eee2eb03..0e68856e 100644 --- a/backend/app/schemas/agents.py +++ b/backend/app/schemas/agents.py @@ -39,6 +39,27 @@ def _normalize_identity_profile( return normalized or None +def _normalize_model_ids( + model_ids: object, +) -> list[UUID] | None: + if model_ids is None: + return None + if not isinstance(model_ids, (list, tuple, set)): + raise ValueError("fallback_model_ids must be a list") + normalized: list[UUID] = [] + seen: set[UUID] = set() + for raw in model_ids: + candidate = str(raw).strip() + if not candidate: + continue + model_id = UUID(candidate) + if model_id in seen: + continue + seen.add(model_id) + normalized.append(model_id) + return normalized or None + + class AgentBase(SQLModel): """Common fields shared by agent create/read/update payloads.""" @@ -46,6 +67,8 @@ class AgentBase(SQLModel): name: NonEmptyStr status: str = "provisioning" heartbeat_config: dict[str, Any] | None = None + primary_model_id: UUID | None = None + fallback_model_ids: list[UUID] | None = None identity_profile: dict[str, Any] | None = None identity_template: str | None = None soul_template: str | None = None @@ -70,6 +93,15 @@ class AgentBase(SQLModel): """Normalize identity-profile values into trimmed string mappings.""" return _normalize_identity_profile(value) + @field_validator("fallback_model_ids", mode="before") + @classmethod + def normalize_fallback_model_ids( + cls, + value: object, + ) -> list[UUID] | None: + """Normalize fallback model ids into ordered UUID values.""" + return _normalize_model_ids(value) + class AgentCreate(AgentBase): """Payload for creating a new agent.""" @@ -83,6 +115,8 @@ class AgentUpdate(SQLModel): name: NonEmptyStr | None = None status: str | None = None heartbeat_config: dict[str, Any] | None = None + primary_model_id: UUID | None = None + fallback_model_ids: list[UUID] | None = None identity_profile: dict[str, Any] | None = None identity_template: str | None = None soul_template: str | None = None @@ -107,6 +141,15 @@ class AgentUpdate(SQLModel): """Normalize identity-profile values into trimmed string mappings.""" return _normalize_identity_profile(value) + @field_validator("fallback_model_ids", mode="before") + @classmethod + def normalize_fallback_model_ids( + cls, + value: object, + ) -> list[UUID] | None: + """Normalize fallback model ids into ordered UUID values.""" + return _normalize_model_ids(value) + class AgentRead(AgentBase): """Public agent representation returned by the API.""" diff --git a/backend/app/schemas/llm_models.py b/backend/app/schemas/llm_models.py new file mode 100644 index 00000000..c9aa9508 --- /dev/null +++ b/backend/app/schemas/llm_models.py @@ -0,0 +1,167 @@ +"""Schemas for LLM provider auth, model catalog, and gateway sync payloads.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import field_validator +from sqlmodel import Field, SQLModel + +from app.schemas.common import NonEmptyStr + +RUNTIME_ANNOTATION_TYPES = (datetime, UUID, NonEmptyStr) + + +def _normalize_provider(value: object) -> str | object: + if isinstance(value, str): + normalized = value.strip().lower() + return normalized or value + return value + + +def _normalize_config_path(value: object) -> str | object: + if isinstance(value, str): + normalized = value.strip() + return normalized or value + return value + + +def _default_provider_config_path(provider: str) -> str: + return f"providers.{provider}.apiKey" + + +class LlmProviderAuthBase(SQLModel): + """Shared provider auth fields.""" + + gateway_id: UUID + provider: NonEmptyStr + config_path: NonEmptyStr | None = None + + @field_validator("provider", mode="before") + @classmethod + def normalize_provider(cls, value: object) -> str | object: + return _normalize_provider(value) + + @field_validator("config_path", mode="before") + @classmethod + def normalize_config_path(cls, value: object) -> str | object: + return _normalize_config_path(value) + + +class LlmProviderAuthCreate(LlmProviderAuthBase): + """Payload used to create a provider auth record.""" + + secret: NonEmptyStr + + +class LlmProviderAuthUpdate(SQLModel): + """Payload used to patch an existing provider auth record.""" + + provider: NonEmptyStr | None = None + config_path: NonEmptyStr | None = None + secret: NonEmptyStr | None = None + + @field_validator("provider", mode="before") + @classmethod + def normalize_provider(cls, value: object) -> str | object: + return _normalize_provider(value) + + @field_validator("config_path", mode="before") + @classmethod + def normalize_config_path(cls, value: object) -> str | object: + return _normalize_config_path(value) + + +class LlmProviderAuthRead(SQLModel): + """Public provider auth payload (secret value is never returned).""" + + id: UUID + organization_id: UUID + gateway_id: UUID + provider: str + config_path: str + has_secret: bool = True + created_at: datetime + updated_at: datetime + + +class LlmModelBase(SQLModel): + """Shared gateway model catalog fields.""" + + gateway_id: UUID + provider: NonEmptyStr + model_id: NonEmptyStr + display_name: NonEmptyStr + settings: dict[str, Any] | None = None + + @field_validator("provider", mode="before") + @classmethod + def normalize_provider(cls, value: object) -> str | object: + return _normalize_provider(value) + + +class LlmModelCreate(LlmModelBase): + """Payload used to create a model catalog entry.""" + + +class LlmModelUpdate(SQLModel): + """Payload used to patch an existing model catalog entry.""" + + provider: NonEmptyStr | None = None + model_id: NonEmptyStr | None = None + display_name: NonEmptyStr | None = None + settings: dict[str, Any] | None = None + + @field_validator("provider", mode="before") + @classmethod + def normalize_provider(cls, value: object) -> str | object: + return _normalize_provider(value) + + +class LlmModelRead(SQLModel): + """Public model catalog entry payload.""" + + id: UUID + organization_id: UUID + gateway_id: UUID + provider: str + model_id: str + display_name: str + settings: dict[str, Any] | None = None + created_at: datetime + updated_at: datetime + + +class GatewayModelSyncResult(SQLModel): + """Summary of model/provider config sync operations for a gateway.""" + + gateway_id: UUID + provider_auth_patched: int + model_catalog_patched: int + agent_models_patched: int + sessions_patched: int + errors: list[str] = Field(default_factory=list) + + +class GatewayModelPullResult(SQLModel): + """Summary of model/provider config pull operations for a gateway.""" + + gateway_id: UUID + provider_auth_imported: int + model_catalog_imported: int + agent_models_imported: int + errors: list[str] = Field(default_factory=list) + + +__all__ = [ + "GatewayModelPullResult", + "GatewayModelSyncResult", + "LlmModelCreate", + "LlmModelRead", + "LlmModelUpdate", + "LlmProviderAuthCreate", + "LlmProviderAuthRead", + "LlmProviderAuthUpdate", +] diff --git a/backend/app/services/openclaw/model_registry_service.py b/backend/app/services/openclaw/model_registry_service.py new file mode 100644 index 00000000..cf8a3729 --- /dev/null +++ b/backend/app/services/openclaw/model_registry_service.py @@ -0,0 +1,1002 @@ +"""Gateway-scoped model registry and provider-auth synchronization service.""" + +from __future__ import annotations + +import json +from typing import Any +from uuid import UUID + +from fastapi import HTTPException, status +from sqlalchemy.exc import IntegrityError +from sqlmodel import col, select + +from app.core.time import utcnow +from app.models.agents import Agent +from app.models.gateways import Gateway +from app.models.llm import LlmModel, LlmProviderAuth +from app.schemas.llm_models import ( + GatewayModelPullResult, + GatewayModelSyncResult, + LlmModelCreate, + LlmModelRead, + LlmModelUpdate, + LlmProviderAuthCreate, + LlmProviderAuthRead, + LlmProviderAuthUpdate, +) +from app.services.openclaw.db_service import OpenClawDBService +from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig +from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call +from app.services.openclaw.internal.agent_key import agent_key as board_agent_key +from app.services.openclaw.provisioning import _heartbeat_config, _workspace_path +from app.services.openclaw.shared import GatewayAgentIdentity +from app.services.organizations import OrganizationContext + + +def _set_nested_path(target: dict[str, object], path: list[str], value: object) -> None: + node: dict[str, object] = target + for key in path[:-1]: + next_node = node.get(key) + if not isinstance(next_node, dict): + next_node = {} + node[key] = next_node + node = next_node + node[path[-1]] = value + + +def _get_nested_path(source: dict[str, object], path: list[str]) -> object | None: + node: object = source + for key in path: + if not isinstance(node, dict): + return None + node = node.get(key) + return node + + +def _normalize_provider(value: object) -> str | None: + if not isinstance(value, str): + return None + provider = value.strip().lower() + return provider or None + + +def _infer_provider_for_model(model_id: str) -> str: + candidate = model_id.strip() + if not candidate: + return "unknown" + for delimiter in ("/", ":"): + if delimiter in candidate: + prefix = candidate.split(delimiter, 1)[0].strip().lower() + if prefix: + return prefix + return "unknown" + + +def _model_settings(raw_value: object) -> dict[str, Any] | None: + if not isinstance(raw_value, dict): + return None + return dict(raw_value) + + +def _parse_agent_model_value(raw_value: object) -> tuple[str | None, list[str]]: + primary_value: str | None + if isinstance(raw_value, str): + primary_value = raw_value.strip() or None + return primary_value, [] + if not isinstance(raw_value, dict): + return None, [] + primary_raw = raw_value.get("primary") + primary_value = primary_raw.strip() if isinstance(primary_raw, str) else None + if not primary_value: + primary_value = None + fallback_raw = raw_value.get("fallbacks") + if fallback_raw is None: + fallback_raw = raw_value.get("fallback") + fallback_values: list[str] = [] + if isinstance(fallback_raw, list): + for item in fallback_raw: + if not isinstance(item, str): + continue + value = item.strip() + if not value: + continue + if primary_value and value == primary_value: + continue + if value in fallback_values: + continue + fallback_values.append(value) + return primary_value, fallback_values + + +def _parse_model_uuid(value: object) -> UUID | None: + if value is None: + return None + candidate = str(value).strip() + if not candidate: + return None + try: + return UUID(candidate) + except ValueError: + return None + + +def _model_config(primary: str | None, fallback: list[str]) -> dict[str, object] | None: + if not primary and not fallback: + return None + value: dict[str, object] = {} + if primary: + value["primary"] = primary + if fallback: + value["fallbacks"] = fallback + return value + + +def _json_to_dict(raw: object) -> dict[str, object] | None: + if isinstance(raw, dict): + return raw + if not isinstance(raw, str): + return None + candidate = raw.strip() + if not candidate: + return {} + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + return None + if isinstance(parsed, dict): + return parsed + return None + + +def _extract_config_data(cfg: dict[str, object]) -> tuple[dict[str, object], str | None]: + # Prefer parsed config over raw serialized content when both are present. + parsed_config = _json_to_dict(cfg.get("parsed")) + if parsed_config is not None: + return parsed_config, cfg.get("hash") if isinstance(cfg.get("hash"), str) else None + raw_config = _json_to_dict(cfg.get("config")) + if raw_config is not None: + return raw_config, cfg.get("hash") if isinstance(cfg.get("hash"), str) else None + # Some gateways return the parsed config object at top-level. + if any(key in cfg for key in ("agents", "providers", "channels")): + return cfg, cfg.get("hash") if isinstance(cfg.get("hash"), str) else None + raise OpenClawGatewayError("config.get returned invalid config") + + +def _constraint_name_from_error(exc: IntegrityError) -> str | None: + diag = getattr(getattr(exc, "orig", None), "diag", None) + if diag is None: + return None + constraint = getattr(diag, "constraint_name", None) + if isinstance(constraint, str) and constraint: + return constraint + return None + + +def _is_constraint_violation(exc: IntegrityError, constraint_name: str) -> bool: + if _constraint_name_from_error(exc) == constraint_name: + return True + return constraint_name in str(getattr(exc, "orig", exc)) + + +class GatewayModelRegistryService(OpenClawDBService): + """Manage provider auth + model catalogs and sync them into gateway config.""" + + MODEL_UNIQUE_CONSTRAINT = "uq_llm_models_gateway_model_id" + PROVIDER_AUTH_UNIQUE_CONSTRAINT = "uq_llm_provider_auth_gateway_provider_path" + + async def require_gateway( + self, + *, + gateway_id: UUID, + organization_id: UUID, + ) -> Gateway: + gateway = ( + await Gateway.objects.by_id(gateway_id) + .filter(col(Gateway.organization_id) == organization_id) + .first(self.session) + ) + if gateway is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found") + return gateway + + async def list_provider_auth( + self, + *, + ctx: OrganizationContext, + gateway_id: UUID | None, + ) -> list[LlmProviderAuthRead]: + statement = select(LlmProviderAuth).where( + col(LlmProviderAuth.organization_id) == ctx.organization.id, + ) + if gateway_id is not None: + statement = statement.where(col(LlmProviderAuth.gateway_id) == gateway_id) + statement = statement.order_by( + col(LlmProviderAuth.provider).asc(), + col(LlmProviderAuth.created_at).desc(), + ) + rows = list(await self.session.exec(statement)) + return [self._to_provider_auth_read(item) for item in rows] + + async def create_provider_auth( + self, + *, + payload: LlmProviderAuthCreate, + ctx: OrganizationContext, + ) -> LlmProviderAuthRead: + gateway = await self.require_gateway( + gateway_id=payload.gateway_id, + organization_id=ctx.organization.id, + ) + config_path = payload.config_path or f"providers.{payload.provider}.apiKey" + existing = ( + await self.session.exec( + select(LlmProviderAuth) + .where(col(LlmProviderAuth.organization_id) == ctx.organization.id) + .where(col(LlmProviderAuth.gateway_id) == gateway.id) + .where(col(LlmProviderAuth.provider) == payload.provider) + .where(col(LlmProviderAuth.config_path) == config_path), + ) + ).first() + if existing is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Provider auth already exists for this gateway/provider/path.", + ) + record = LlmProviderAuth( + organization_id=ctx.organization.id, + gateway_id=gateway.id, + provider=payload.provider, + config_path=config_path, + secret=payload.secret, + ) + self.session.add(record) + try: + await self.session.commit() + except IntegrityError as exc: + await self.session.rollback() + if _is_constraint_violation(exc, self.PROVIDER_AUTH_UNIQUE_CONSTRAINT): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Provider auth already exists for this gateway/provider/path.", + ) from exc + raise + await self.session.refresh(record) + await self.sync_gateway_config(gateway=gateway, organization_id=ctx.organization.id) + return self._to_provider_auth_read(record) + + async def update_provider_auth( + self, + *, + provider_auth_id: UUID, + payload: LlmProviderAuthUpdate, + ctx: OrganizationContext, + ) -> LlmProviderAuthRead: + record = await LlmProviderAuth.objects.by_id(provider_auth_id).first(self.session) + if record is None or record.organization_id != ctx.organization.id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Provider auth not found" + ) + updates = payload.model_dump(exclude_unset=True) + if "provider" in updates: + updates["provider"] = str(updates["provider"]).strip().lower() + if "config_path" in updates and updates["config_path"] is not None: + updates["config_path"] = str(updates["config_path"]).strip() + if ( + "provider" in updates + and "config_path" not in updates + and (record.config_path == f"providers.{record.provider}.apiKey") + ): + updates["config_path"] = f"providers.{updates['provider']}.apiKey" + candidate_provider = str(updates.get("provider", record.provider)).strip().lower() + candidate_path = str(updates.get("config_path", record.config_path)).strip() + duplicate = ( + await self.session.exec( + select(LlmProviderAuth.id) + .where(col(LlmProviderAuth.organization_id) == ctx.organization.id) + .where(col(LlmProviderAuth.gateway_id) == record.gateway_id) + .where(col(LlmProviderAuth.provider) == candidate_provider) + .where(col(LlmProviderAuth.config_path) == candidate_path) + .where(col(LlmProviderAuth.id) != record.id), + ) + ).first() + if duplicate is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Provider auth already exists for this gateway/provider/path.", + ) + for key, value in updates.items(): + setattr(record, key, value) + record.updated_at = utcnow() + self.session.add(record) + try: + await self.session.commit() + except IntegrityError as exc: + await self.session.rollback() + if _is_constraint_violation(exc, self.PROVIDER_AUTH_UNIQUE_CONSTRAINT): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Provider auth already exists for this gateway/provider/path.", + ) from exc + raise + await self.session.refresh(record) + await self.sync_gateway_config( + gateway=await self.require_gateway( + gateway_id=record.gateway_id, + organization_id=ctx.organization.id, + ), + organization_id=ctx.organization.id, + ) + return self._to_provider_auth_read(record) + + async def delete_provider_auth( + self, + *, + provider_auth_id: UUID, + ctx: OrganizationContext, + ) -> None: + record = await LlmProviderAuth.objects.by_id(provider_auth_id).first(self.session) + if record is None or record.organization_id != ctx.organization.id: + return + gateway_id = record.gateway_id + await self.session.delete(record) + await self.session.commit() + await self.sync_gateway_config( + gateway=await self.require_gateway( + gateway_id=gateway_id, + organization_id=ctx.organization.id, + ), + organization_id=ctx.organization.id, + ) + + async def list_models( + self, + *, + ctx: OrganizationContext, + gateway_id: UUID | None, + ) -> list[LlmModelRead]: + statement = select(LlmModel).where(col(LlmModel.organization_id) == ctx.organization.id) + if gateway_id is not None: + statement = statement.where(col(LlmModel.gateway_id) == gateway_id) + statement = statement.order_by(col(LlmModel.provider).asc(), col(LlmModel.model_id).asc()) + rows = list(await self.session.exec(statement)) + return [LlmModelRead.model_validate(item, from_attributes=True) for item in rows] + + async def create_model( + self, + *, + payload: LlmModelCreate, + ctx: OrganizationContext, + ) -> LlmModelRead: + gateway = await self.require_gateway( + gateway_id=payload.gateway_id, + organization_id=ctx.organization.id, + ) + existing = ( + await self.session.exec( + select(LlmModel.id) + .where(col(LlmModel.organization_id) == ctx.organization.id) + .where(col(LlmModel.gateway_id) == gateway.id) + .where(col(LlmModel.model_id) == payload.model_id), + ) + ).first() + if existing is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Model already exists in this gateway catalog.", + ) + model = LlmModel( + organization_id=ctx.organization.id, + gateway_id=gateway.id, + provider=payload.provider, + model_id=payload.model_id, + display_name=payload.display_name, + settings=payload.settings, + ) + self.session.add(model) + try: + await self.session.commit() + except IntegrityError as exc: + await self.session.rollback() + if _is_constraint_violation(exc, self.MODEL_UNIQUE_CONSTRAINT): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Model already exists in this gateway catalog.", + ) from exc + raise + await self.session.refresh(model) + await self.sync_gateway_config(gateway=gateway, organization_id=ctx.organization.id) + return LlmModelRead.model_validate(model, from_attributes=True) + + async def update_model( + self, + *, + model_id: UUID, + payload: LlmModelUpdate, + ctx: OrganizationContext, + ) -> LlmModelRead: + model = await LlmModel.objects.by_id(model_id).first(self.session) + if model is None or model.organization_id != ctx.organization.id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") + updates = payload.model_dump(exclude_unset=True) + if "provider" in updates and updates["provider"] is not None: + updates["provider"] = str(updates["provider"]).strip().lower() + if "model_id" in updates and updates["model_id"] is not None: + candidate_model_id = str(updates["model_id"]).strip() + updates["model_id"] = candidate_model_id + duplicate = ( + await self.session.exec( + select(LlmModel.id) + .where(col(LlmModel.organization_id) == ctx.organization.id) + .where(col(LlmModel.gateway_id) == model.gateway_id) + .where(col(LlmModel.model_id) == candidate_model_id) + .where(col(LlmModel.id) != model.id), + ) + ).first() + if duplicate is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Model already exists in this gateway catalog.", + ) + for key, value in updates.items(): + setattr(model, key, value) + model.updated_at = utcnow() + self.session.add(model) + try: + await self.session.commit() + except IntegrityError as exc: + await self.session.rollback() + if _is_constraint_violation(exc, self.MODEL_UNIQUE_CONSTRAINT): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Model already exists in this gateway catalog.", + ) from exc + raise + await self.session.refresh(model) + await self.sync_gateway_config( + gateway=await self.require_gateway( + gateway_id=model.gateway_id, + organization_id=ctx.organization.id, + ), + organization_id=ctx.organization.id, + ) + return LlmModelRead.model_validate(model, from_attributes=True) + + async def delete_model( + self, + *, + model_id: UUID, + ctx: OrganizationContext, + ) -> None: + model = await LlmModel.objects.by_id(model_id).first(self.session) + if model is None or model.organization_id != ctx.organization.id: + return + gateway_id = model.gateway_id + removed_id = model.id + await self.session.delete(model) + await self.session.commit() + + agents = await Agent.objects.filter_by(gateway_id=gateway_id).all(self.session) + changed = False + for agent in agents: + agent_changed = False + if agent.primary_model_id == removed_id: + agent.primary_model_id = None + agent_changed = True + raw_fallback = agent.fallback_model_ids or [] + if not isinstance(raw_fallback, list): + continue + filtered = [] + for item in raw_fallback: + parsed = _parse_model_uuid(item) + if parsed is None or parsed == removed_id: + continue + filtered.append(str(parsed)) + if filtered != raw_fallback: + agent.fallback_model_ids = filtered or None + agent_changed = True + if agent_changed: + agent.updated_at = utcnow() + self.session.add(agent) + changed = True + if changed: + await self.session.commit() + + await self.sync_gateway_config( + gateway=await self.require_gateway( + gateway_id=gateway_id, + organization_id=ctx.organization.id, + ), + organization_id=ctx.organization.id, + ) + + async def pull_gateway_config( + self, + *, + gateway: Gateway, + organization_id: UUID, + ) -> GatewayModelPullResult: + """Import provider auth, model catalog, and agent model links from gateway config.""" + if gateway.organization_id != organization_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found") + if not gateway.url: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Gateway URL is not configured.", + ) + + config = GatewayClientConfig(url=gateway.url, token=gateway.token) + result = GatewayModelPullResult( + gateway_id=gateway.id, + provider_auth_imported=0, + model_catalog_imported=0, + agent_models_imported=0, + errors=[], + ) + + try: + cfg = await openclaw_call("config.get", config=config) + if not isinstance(cfg, dict): + raise OpenClawGatewayError("config.get returned invalid payload") + config_data, _ = _extract_config_data(cfg) + except OpenClawGatewayError as exc: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Gateway model pull failed: {exc}", + ) from exc + + has_db_changes = False + has_pending_models = False + + provider_auth_rows = await LlmProviderAuth.objects.filter_by( + organization_id=organization_id, + gateway_id=gateway.id, + ).all(self.session) + provider_auth_by_key: dict[tuple[str, str], LlmProviderAuth] = { + (item.provider, item.config_path): item for item in provider_auth_rows + } + + for provider_auth_item in provider_auth_rows: + path = [part.strip() for part in provider_auth_item.config_path.split(".") if part.strip()] + if not path: + continue + pulled_secret_value = _get_nested_path(config_data, path) + if not isinstance(pulled_secret_value, str): + continue + secret = pulled_secret_value.strip() + if not secret or secret == provider_auth_item.secret: + continue + provider_auth_item.secret = secret + provider_auth_item.updated_at = utcnow() + self.session.add(provider_auth_item) + result.provider_auth_imported += 1 + has_db_changes = True + + providers_data = config_data.get("providers") + if isinstance(providers_data, dict): + for raw_provider, raw_provider_config in providers_data.items(): + provider = _normalize_provider(raw_provider) + if not provider: + continue + pulled_secret: str | None = None + config_path: str | None = None + if isinstance(raw_provider_config, dict): + raw_api_key = raw_provider_config.get("apiKey") + if isinstance(raw_api_key, str): + api_key = raw_api_key.strip() + if api_key: + pulled_secret = api_key + config_path = f"providers.{provider}.apiKey" + elif isinstance(raw_provider_config, str): + secret = raw_provider_config.strip() + if secret: + pulled_secret = secret + config_path = f"providers.{provider}" + if not pulled_secret or not config_path: + continue + provider_key = (provider, config_path) + existing = provider_auth_by_key.get(provider_key) + if existing is None: + record = LlmProviderAuth( + organization_id=organization_id, + gateway_id=gateway.id, + provider=provider, + config_path=config_path, + secret=pulled_secret, + ) + self.session.add(record) + provider_auth_by_key[provider_key] = record + result.provider_auth_imported += 1 + has_db_changes = True + continue + if existing.secret == pulled_secret: + continue + existing.secret = pulled_secret + existing.updated_at = utcnow() + self.session.add(existing) + result.provider_auth_imported += 1 + has_db_changes = True + + existing_models = await LlmModel.objects.filter_by( + organization_id=organization_id, + gateway_id=gateway.id, + ).all(self.session) + models_by_model_id: dict[str, LlmModel] = {item.model_id: item for item in existing_models} + + catalog_models_data: dict[str, object] = {} + agents_data = config_data.get("agents") + if isinstance(agents_data, dict): + defaults_data = agents_data.get("defaults") + if isinstance(defaults_data, dict): + raw_models = defaults_data.get("models") + if isinstance(raw_models, dict): + catalog_models_data = raw_models + + for raw_model_id, raw_model_config in catalog_models_data.items(): + if not isinstance(raw_model_id, str): + result.errors.append("Skipped one catalog model: model id is not a string.") + continue + model_id = raw_model_id.strip() + if not model_id: + result.errors.append("Skipped one catalog model: model id is empty.") + continue + + settings = _model_settings(raw_model_config) + provider_from_settings = ( + _normalize_provider(settings.get("provider")) if settings is not None else None + ) + provider = provider_from_settings or _infer_provider_for_model(model_id) + display_name = model_id + if settings: + for display_key in ("display_name", "displayName", "name"): + candidate = settings.get(display_key) + if isinstance(candidate, str) and candidate.strip(): + display_name = candidate.strip() + break + + existing_model = models_by_model_id.get(model_id) + if existing_model is None: + model = LlmModel( + organization_id=organization_id, + gateway_id=gateway.id, + provider=provider, + model_id=model_id, + display_name=display_name, + settings=settings, + ) + self.session.add(model) + models_by_model_id[model_id] = model + result.model_catalog_imported += 1 + has_db_changes = True + has_pending_models = True + continue + + model_changed = False + if existing_model.provider != provider: + existing_model.provider = provider + model_changed = True + if existing_model.display_name != display_name: + existing_model.display_name = display_name + model_changed = True + if existing_model.settings != settings: + existing_model.settings = settings + model_changed = True + if model_changed: + existing_model.updated_at = utcnow() + self.session.add(existing_model) + result.model_catalog_imported += 1 + has_db_changes = True + + agents = await Agent.objects.filter_by(gateway_id=gateway.id).all(self.session) + agents_by_openclaw_id: dict[str, Agent] = {} + for agent in agents: + if agent.board_id is None: + agent_id = GatewayAgentIdentity.openclaw_agent_id(gateway) + else: + agent_id = board_agent_key(agent) + agents_by_openclaw_id[agent_id] = agent + + raw_agents_list: list[object] = [] + if isinstance(agents_data, dict): + raw_agent_values = agents_data.get("list") or [] + if isinstance(raw_agent_values, list): + raw_agents_list = raw_agent_values + + assignments_by_agent: dict[UUID, tuple[str | None, list[str]]] = {} + assignment_model_ids: set[str] = set() + for raw_entry in raw_agents_list: + if not isinstance(raw_entry, dict): + continue + raw_agent_id = raw_entry.get("id") + if not isinstance(raw_agent_id, str): + continue + agent_id = raw_agent_id.strip() + if not agent_id: + continue + resolved_agent = agents_by_openclaw_id.get(agent_id) + if resolved_agent is None: + continue + if "model" not in raw_entry: + continue + primary_model_id, fallback_model_ids = _parse_agent_model_value(raw_entry.get("model")) + assignments_by_agent[resolved_agent.id] = (primary_model_id, fallback_model_ids) + if primary_model_id: + assignment_model_ids.add(primary_model_id) + assignment_model_ids.update(fallback_model_ids) + + for model_id in assignment_model_ids: + if model_id in models_by_model_id: + continue + model = LlmModel( + organization_id=organization_id, + gateway_id=gateway.id, + provider=_infer_provider_for_model(model_id), + model_id=model_id, + display_name=model_id, + settings=None, + ) + self.session.add(model) + models_by_model_id[model_id] = model + result.model_catalog_imported += 1 + has_db_changes = True + has_pending_models = True + + if has_pending_models: + await self.session.flush() + + changed_agents = 0 + for agent in agents: + model_assignment = assignments_by_agent.get(agent.id) + if model_assignment is None: + continue + primary_model_key, fallback_model_keys = model_assignment + + primary_model_uuid: UUID | None = None + if primary_model_key: + primary_model = models_by_model_id.get(primary_model_key) + if primary_model is None: + result.errors.append( + f"Skipped primary model '{primary_model_key}' for agent '{agent.name}'.", + ) + else: + primary_model_uuid = primary_model.id + + fallback_values: list[str] = [] + for model_key in fallback_model_keys: + resolved_model = models_by_model_id.get(model_key) + if resolved_model is None: + result.errors.append( + f"Skipped fallback model '{model_key}' for agent '{agent.name}'.", + ) + continue + if primary_model_uuid is not None and resolved_model.id == primary_model_uuid: + continue + fallback_id = str(resolved_model.id) + if fallback_id in fallback_values: + continue + fallback_values.append(fallback_id) + normalized_fallback_model_ids: list[str] | None = fallback_values or None + + current_fallback_values: list[str] = [] + for raw_value in agent.fallback_model_ids or []: + parsed = _parse_model_uuid(raw_value) + if parsed is None: + continue + value = str(parsed) + if value in current_fallback_values: + continue + current_fallback_values.append(value) + current_fallback_model_ids = current_fallback_values or None + + if ( + agent.primary_model_id == primary_model_uuid + and current_fallback_model_ids == normalized_fallback_model_ids + ): + continue + agent.primary_model_id = primary_model_uuid + agent.fallback_model_ids = normalized_fallback_model_ids + agent.updated_at = utcnow() + self.session.add(agent) + changed_agents += 1 + has_db_changes = True + + result.agent_models_imported = changed_agents + if has_db_changes: + await self.session.commit() + return result + + async def sync_gateway_config( + self, + *, + gateway: Gateway, + organization_id: UUID, + ) -> GatewayModelSyncResult: + if gateway.organization_id != organization_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found") + if not gateway.url: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Gateway URL is not configured.", + ) + + config = GatewayClientConfig(url=gateway.url, token=gateway.token) + provider_auth = await LlmProviderAuth.objects.filter_by( + organization_id=organization_id, + gateway_id=gateway.id, + ).all(self.session) + model_catalog = await LlmModel.objects.filter_by( + organization_id=organization_id, + gateway_id=gateway.id, + ).all(self.session) + agents = await Agent.objects.filter_by(gateway_id=gateway.id).all(self.session) + model_id_map = {model.id: model.model_id for model in model_catalog} + default_primary_model: str | None = None + + result = GatewayModelSyncResult( + gateway_id=gateway.id, + provider_auth_patched=0, + model_catalog_patched=0, + agent_models_patched=0, + sessions_patched=0, + errors=[], + ) + + try: + cfg = await openclaw_call("config.get", config=config) + if not isinstance(cfg, dict): + raise OpenClawGatewayError("config.get returned invalid payload") + config_data, base_hash = _extract_config_data(cfg) + + patch: dict[str, object] = {} + for provider_auth_item in provider_auth: + path = [ + part.strip() + for part in provider_auth_item.config_path.split(".") + if part.strip() + ] + if not path: + result.errors.append( + f"Skipped provider auth {provider_auth_item.id}: config_path is empty.", + ) + continue + _set_nested_path(patch, path, provider_auth_item.secret) + result.provider_auth_patched += 1 + + if model_catalog: + models_patch: dict[str, object] = {} + for model in model_catalog: + value = dict(model.settings or {}) + # Gateway model objects reject provider metadata; store provider in DB only. + value.pop("provider", None) + models_patch[model.model_id] = value + _set_nested_path(patch, ["agents", "defaults", "models"], models_patch) + result.model_catalog_patched = len(model_catalog) + + existing_primary = None + agents_section = config_data.get("agents") + if isinstance(agents_section, dict): + defaults_section = agents_section.get("defaults") + if isinstance(defaults_section, dict): + model_section = defaults_section.get("model") + if isinstance(model_section, dict): + candidate = model_section.get("primary") + if isinstance(candidate, str) and candidate: + existing_primary = candidate + if existing_primary in models_patch: + default_primary_model = existing_primary + else: + first_model = model_catalog[0].model_id + _set_nested_path( + patch, + ["agents", "defaults", "model", "primary"], + first_model, + ) + default_primary_model = first_model + + raw_agents_list: list[object] = [] + agents_section = config_data.get("agents") + if isinstance(agents_section, dict): + raw_agents_list = agents_section.get("list") or [] + if not isinstance(raw_agents_list, list): + raw_agents_list = [] + + existing_entries: dict[str, dict[str, object]] = {} + passthrough_entries: list[object] = [] + for raw_entry in raw_agents_list: + if not isinstance(raw_entry, dict): + passthrough_entries.append(raw_entry) + continue + entry_id = raw_entry.get("id") + if isinstance(entry_id, str) and entry_id: + existing_entries[entry_id] = dict(raw_entry) + else: + passthrough_entries.append(raw_entry) + + updated_entries: list[object] = [] + for agent in agents: + if agent.board_id is None: + agent_id = GatewayAgentIdentity.openclaw_agent_id(gateway) + else: + agent_id = board_agent_key(agent) + agent_entry = existing_entries.pop(agent_id, None) + if agent_entry is None: + agent_entry = { + "id": agent_id, + "workspace": _workspace_path(agent, gateway.workspace_root), + "heartbeat": _heartbeat_config(agent), + } + + primary_model_id = ( + model_id_map.get(agent.primary_model_id) if agent.primary_model_id else None + ) + fallback_values: list[str] = [] + for raw_value in agent.fallback_model_ids or []: + parsed = _parse_model_uuid(raw_value) + if parsed is None: + continue + mapped = model_id_map.get(parsed) + if not mapped: + continue + if primary_model_id and mapped == primary_model_id: + continue + if mapped in fallback_values: + continue + fallback_values.append(mapped) + model_value = _model_config(primary_model_id, fallback_values) + if model_value is None: + agent_entry.pop("model", None) + else: + agent_entry["model"] = model_value + result.agent_models_patched += 1 + updated_entries.append(agent_entry) + + for remaining in existing_entries.values(): + updated_entries.append(remaining) + updated_entries.extend(passthrough_entries) + _set_nested_path(patch, ["agents", "list"], updated_entries) + + params: dict[str, object] = {"raw": json.dumps(patch)} + if base_hash: + params["baseHash"] = base_hash + await openclaw_call("config.patch", params, config=config) + except OpenClawGatewayError as exc: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Gateway model sync failed: {exc}", + ) from exc + + for agent in agents: + primary = model_id_map.get(agent.primary_model_id) if agent.primary_model_id else None + if not primary: + primary = default_primary_model + if not primary: + continue + session_key = (agent.openclaw_session_id or "").strip() + if not session_key: + continue + try: + await openclaw_call( + "sessions.patch", + { + "key": session_key, + "label": agent.name, + "model": primary, + }, + config=config, + ) + result.sessions_patched += 1 + except OpenClawGatewayError as exc: + result.errors.append(f"{agent.name}: {exc}") + return result + + @staticmethod + def _to_provider_auth_read(record: LlmProviderAuth) -> LlmProviderAuthRead: + return LlmProviderAuthRead( + id=record.id, + organization_id=record.organization_id, + gateway_id=record.gateway_id, + provider=record.provider, + config_path=record.config_path, + has_secret=bool(record.secret), + created_at=record.created_at, + updated_at=record.updated_at, + ) diff --git a/backend/app/services/openclaw/provisioning.py b/backend/app/services/openclaw/provisioning.py index 96becad5..b8265ffa 100644 --- a/backend/app/services/openclaw/provisioning.py +++ b/backend/app/services/openclaw/provisioning.py @@ -492,10 +492,32 @@ async def _gateway_config_agent_list( msg = "config.get returned invalid payload" raise OpenClawGatewayError(msg) - data = cfg.get("config") or cfg.get("parsed") or {} - if not isinstance(data, dict): - msg = "config.get returned invalid config" - raise OpenClawGatewayError(msg) + # Prefer parsed object over raw serialized config to support older gateways. + raw_parsed = cfg.get("parsed") + raw_config = cfg.get("config") + data: dict[str, Any] + + def _parse_json_config(raw: str) -> dict[str, Any]: + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + msg = "config.get returned invalid config" + raise OpenClawGatewayError(msg) from exc + if not isinstance(parsed, dict): + msg = "config.get returned invalid config" + raise OpenClawGatewayError(msg) + return parsed + + if isinstance(raw_parsed, dict): + data = raw_parsed + elif isinstance(raw_config, dict): + data = raw_config + elif isinstance(raw_parsed, str) and raw_parsed.strip(): + data = _parse_json_config(raw_parsed) + elif isinstance(raw_config, str) and raw_config.strip(): + data = _parse_json_config(raw_config) + else: + data = {} agents_section = data.get("agents") or {} agents_list = agents_section.get("list") or [] diff --git a/backend/app/services/openclaw/provisioning_db.py b/backend/app/services/openclaw/provisioning_db.py index a7ab59ba..f8a3f9b6 100644 --- a/backend/app/services/openclaw/provisioning_db.py +++ b/backend/app/services/openclaw/provisioning_db.py @@ -32,6 +32,7 @@ from app.models.agents import Agent from app.models.board_memory import BoardMemory from app.models.boards import Board from app.models.gateways import Gateway +from app.models.llm import LlmModel from app.models.organizations import Organization from app.models.tasks import Task from app.schemas.agents import ( @@ -72,6 +73,7 @@ from app.services.openclaw.internal.session_keys import ( board_agent_session_key, board_lead_session_key, ) +from app.services.openclaw.model_registry_service import GatewayModelRegistryService from app.services.openclaw.policies import OpenClawAuthorizationPolicy from app.services.openclaw.provisioning import ( OpenClawGatewayControlPlane, @@ -959,6 +961,71 @@ class AgentLifecycleService(OpenClawDBService): detail="An agent with this name already exists in this gateway workspace.", ) + @staticmethod + def _normalized_fallback_ids(value: object) -> list[UUID]: + if not isinstance(value, list): + return [] + normalized: list[UUID] = [] + seen: set[UUID] = set() + for raw in value: + candidate = str(raw).strip() + if not candidate: + continue + try: + model_id = UUID(candidate) + except ValueError: + continue + if model_id in seen: + continue + seen.add(model_id) + normalized.append(model_id) + return normalized + + async def normalize_agent_model_assignments( + self, + *, + gateway_id: UUID, + primary_model_id: UUID | None, + fallback_model_ids: list[UUID], + ) -> tuple[UUID | None, list[str] | None]: + if primary_model_id is None and not fallback_model_ids: + return None, None + + candidate_ids: list[UUID] = [] + if primary_model_id is not None: + candidate_ids.append(primary_model_id) + candidate_ids.extend(fallback_model_ids) + unique_ids = list(dict.fromkeys(candidate_ids)) + statement = ( + select(LlmModel.id) + .where(col(LlmModel.gateway_id) == gateway_id) + .where(col(LlmModel.id).in_(unique_ids)) + ) + with self.session.no_autoflush: + valid_ids = set(await self.session.exec(statement)) + missing = [value for value in unique_ids if value not in valid_ids] + if missing: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Model assignment includes model ids not in the agent gateway catalog.", + ) + + filtered_fallback: list[str] = [] + for fallback_id in fallback_model_ids: + if primary_model_id is not None and fallback_id == primary_model_id: + continue + value = str(fallback_id) + if value in filtered_fallback: + continue + filtered_fallback.append(value) + return primary_model_id, (filtered_fallback or None) + + async def sync_gateway_agent_models(self, *, gateway: Gateway) -> None: + await GatewayModelRegistryService(self.session).sync_gateway_config( + gateway=gateway, + organization_id=gateway.organization_id, + ) + async def persist_new_agent( self, *, @@ -1177,6 +1244,13 @@ class AgentLifecycleService(OpenClawDBService): detail="Board gateway_id is required", ) updates["gateway_id"] = board.gateway_id + if "primary_model_id" in updates: + primary_value = updates["primary_model_id"] + if primary_value is not None and not isinstance(primary_value, UUID): + updates["primary_model_id"] = UUID(str(primary_value)) + if "fallback_model_ids" in updates: + normalized_fallback = self._normalized_fallback_ids(updates["fallback_model_ids"]) + updates["fallback_model_ids"] = [str(model_id) for model_id in normalized_fallback] or None for key, value in updates.items(): setattr(agent, key, value) @@ -1192,6 +1266,19 @@ class AgentLifecycleService(OpenClawDBService): detail="Board gateway_id is required", ) agent.gateway_id = board.gateway_id + + primary_model_id = agent.primary_model_id + if primary_model_id is not None and not isinstance(primary_model_id, UUID): + primary_model_id = UUID(str(primary_model_id)) + fallback_model_ids = self._normalized_fallback_ids(agent.fallback_model_ids) + normalized_primary, normalized_fallback = await self.normalize_agent_model_assignments( + gateway_id=agent.gateway_id, + primary_model_id=primary_model_id if isinstance(primary_model_id, UUID) else None, + fallback_model_ids=fallback_model_ids, + ) + agent.primary_model_id = normalized_primary + agent.fallback_model_ids = normalized_fallback + agent.updated_at = utcnow() if agent.heartbeat_config is None: agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() @@ -1487,6 +1574,17 @@ class AgentLifecycleService(OpenClawDBService): gateway, _client_config = await self.require_gateway(board) data = payload.model_dump() data["gateway_id"] = gateway.id + primary_model_id = data.get("primary_model_id") + if primary_model_id is not None and not isinstance(primary_model_id, UUID): + primary_model_id = UUID(str(primary_model_id)) + fallback_model_ids = self._normalized_fallback_ids(data.get("fallback_model_ids")) + normalized_primary, normalized_fallback = await self.normalize_agent_model_assignments( + gateway_id=gateway.id, + primary_model_id=primary_model_id if isinstance(primary_model_id, UUID) else None, + fallback_model_ids=fallback_model_ids, + ) + data["primary_model_id"] = normalized_primary + data["fallback_model_ids"] = normalized_fallback requested_name = (data.get("name") or "").strip() await self.ensure_unique_agent_name( board=board, @@ -1502,6 +1600,8 @@ class AgentLifecycleService(OpenClawDBService): user=actor.user if actor.actor_type == "user" else None, force_bootstrap=False, ) + if agent.primary_model_id is not None or agent.fallback_model_ids: + await self.sync_gateway_agent_models(gateway=gateway) self.logger.info("agent.create.success agent_id=%s board_id=%s", agent.id, board.id) return self.to_agent_read(self.with_computed_status(agent)) @@ -1535,6 +1635,7 @@ class AgentLifecycleService(OpenClawDBService): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) await self.require_agent_access(agent=agent, ctx=options.context, write=True) updates = payload.model_dump(exclude_unset=True) + sync_model_assignments = "primary_model_id" in updates or "fallback_model_ids" in updates make_main = updates.pop("is_gateway_main", None) await self.validate_agent_update_inputs( ctx=options.context, @@ -1568,6 +1669,8 @@ class AgentLifecycleService(OpenClawDBService): agent=agent, request=provision_request, ) + if sync_model_assignments or agent.primary_model_id is not None or agent.fallback_model_ids: + await self.sync_gateway_agent_models(gateway=target.gateway) self.logger.info("agent.update.success agent_id=%s", agent.id) return self.to_agent_read(self.with_computed_status(agent)) diff --git a/backend/migrations/versions/9a3fb1158c2d_add_llm_model_registry.py b/backend/migrations/versions/9a3fb1158c2d_add_llm_model_registry.py new file mode 100644 index 00000000..b4a16cb5 --- /dev/null +++ b/backend/migrations/versions/9a3fb1158c2d_add_llm_model_registry.py @@ -0,0 +1,285 @@ +"""add llm model registry + +Revision ID: 9a3fb1158c2d +Revises: f4d2b649e93a +Create Date: 2026-02-11 21:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "9a3fb1158c2d" +down_revision = "f4d2b649e93a" +branch_labels = None +depends_on = None + + +def _has_index(inspector: sa.Inspector, table: str, index_name: str) -> bool: + return any(item.get("name") == index_name for item in inspector.get_indexes(table)) + + +def _has_unique( + inspector: sa.Inspector, + table: str, + *, + name: str | None = None, + columns: tuple[str, ...] | None = None, +) -> bool: + unique_constraints = inspector.get_unique_constraints(table) + for item in unique_constraints: + if name and item.get("name") == name: + return True + if columns and tuple(item.get("column_names") or ()) == columns: + return True + return False + + +def _column_names(inspector: sa.Inspector, table: str) -> set[str]: + return {item["name"] for item in inspector.get_columns(table)} + + +def _has_foreign_key( + inspector: sa.Inspector, + table: str, + *, + constrained_columns: tuple[str, ...], + referred_table: str, +) -> bool: + for item in inspector.get_foreign_keys(table): + if tuple(item.get("constrained_columns") or ()) != constrained_columns: + continue + if item.get("referred_table") != referred_table: + continue + return True + return False + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + + if not inspector.has_table("llm_provider_auth"): + op.create_table( + "llm_provider_auth", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("organization_id", sa.Uuid(), nullable=False), + sa.Column("gateway_id", sa.Uuid(), nullable=False), + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("config_path", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(["gateway_id"], ["gateways.id"]), + sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "gateway_id", + "provider", + "config_path", + name="uq_llm_provider_auth_gateway_provider_path", + ), + ) + inspector = sa.inspect(bind) + else: + existing_columns = _column_names(inspector, "llm_provider_auth") + if "config_path" not in existing_columns: + op.add_column( + "llm_provider_auth", + sa.Column( + "config_path", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + server_default="providers.openai.apiKey", + ), + ) + op.alter_column("llm_provider_auth", "config_path", server_default=None) + inspector = sa.inspect(bind) + if not _has_unique( + inspector, + "llm_provider_auth", + name="uq_llm_provider_auth_gateway_provider_path", + columns=("gateway_id", "provider", "config_path"), + ): + op.create_unique_constraint( + "uq_llm_provider_auth_gateway_provider_path", + "llm_provider_auth", + ["gateway_id", "provider", "config_path"], + ) + inspector = sa.inspect(bind) + + if not _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_gateway_id")): + op.create_index( + op.f("ix_llm_provider_auth_gateway_id"), + "llm_provider_auth", + ["gateway_id"], + unique=False, + ) + inspector = sa.inspect(bind) + if not _has_index( + inspector, + "llm_provider_auth", + op.f("ix_llm_provider_auth_organization_id"), + ): + op.create_index( + op.f("ix_llm_provider_auth_organization_id"), + "llm_provider_auth", + ["organization_id"], + unique=False, + ) + inspector = sa.inspect(bind) + if not _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_provider")): + op.create_index( + op.f("ix_llm_provider_auth_provider"), + "llm_provider_auth", + ["provider"], + unique=False, + ) + inspector = sa.inspect(bind) + + if not inspector.has_table("llm_models"): + op.create_table( + "llm_models", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("organization_id", sa.Uuid(), nullable=False), + sa.Column("gateway_id", sa.Uuid(), nullable=False), + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("model_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("settings", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(["gateway_id"], ["gateways.id"]), + sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("gateway_id", "model_id", name="uq_llm_models_gateway_model_id"), + ) + inspector = sa.inspect(bind) + elif not _has_unique( + inspector, + "llm_models", + name="uq_llm_models_gateway_model_id", + columns=("gateway_id", "model_id"), + ): + op.create_unique_constraint( + "uq_llm_models_gateway_model_id", + "llm_models", + ["gateway_id", "model_id"], + ) + inspector = sa.inspect(bind) + + if not _has_index(inspector, "llm_models", op.f("ix_llm_models_gateway_id")): + op.create_index( + op.f("ix_llm_models_gateway_id"), + "llm_models", + ["gateway_id"], + unique=False, + ) + inspector = sa.inspect(bind) + if not _has_index(inspector, "llm_models", op.f("ix_llm_models_model_id")): + op.create_index( + op.f("ix_llm_models_model_id"), + "llm_models", + ["model_id"], + unique=False, + ) + inspector = sa.inspect(bind) + if not _has_index(inspector, "llm_models", op.f("ix_llm_models_organization_id")): + op.create_index( + op.f("ix_llm_models_organization_id"), + "llm_models", + ["organization_id"], + unique=False, + ) + inspector = sa.inspect(bind) + if not _has_index(inspector, "llm_models", op.f("ix_llm_models_provider")): + op.create_index( + op.f("ix_llm_models_provider"), + "llm_models", + ["provider"], + unique=False, + ) + inspector = sa.inspect(bind) + + agent_columns = _column_names(inspector, "agents") + if "primary_model_id" not in agent_columns: + op.add_column("agents", sa.Column("primary_model_id", sa.Uuid(), nullable=True)) + inspector = sa.inspect(bind) + if "fallback_model_ids" not in agent_columns: + op.add_column("agents", sa.Column("fallback_model_ids", sa.JSON(), nullable=True)) + inspector = sa.inspect(bind) + if not _has_index(inspector, "agents", op.f("ix_agents_primary_model_id")): + op.create_index(op.f("ix_agents_primary_model_id"), "agents", ["primary_model_id"], unique=False) + inspector = sa.inspect(bind) + if not _has_foreign_key( + inspector, + "agents", + constrained_columns=("primary_model_id",), + referred_table="llm_models", + ): + op.create_foreign_key( + "fk_agents_primary_model_id_llm_models", + "agents", + "llm_models", + ["primary_model_id"], + ["id"], + ) + + +def downgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + + if inspector.has_table("agents"): + for fk in inspector.get_foreign_keys("agents"): + if tuple(fk.get("constrained_columns") or ()) != ("primary_model_id",): + continue + if fk.get("referred_table") != "llm_models": + continue + fk_name = fk.get("name") + if fk_name: + op.drop_constraint(fk_name, "agents", type_="foreignkey") + inspector = sa.inspect(bind) + if _has_index(inspector, "agents", op.f("ix_agents_primary_model_id")): + op.drop_index(op.f("ix_agents_primary_model_id"), table_name="agents") + agent_columns = _column_names(inspector, "agents") + if "fallback_model_ids" in agent_columns: + op.drop_column("agents", "fallback_model_ids") + if "primary_model_id" in agent_columns: + op.drop_column("agents", "primary_model_id") + + inspector = sa.inspect(bind) + if inspector.has_table("llm_models"): + if _has_index(inspector, "llm_models", op.f("ix_llm_models_provider")): + op.drop_index(op.f("ix_llm_models_provider"), table_name="llm_models") + if _has_index(inspector, "llm_models", op.f("ix_llm_models_organization_id")): + op.drop_index(op.f("ix_llm_models_organization_id"), table_name="llm_models") + if _has_index(inspector, "llm_models", op.f("ix_llm_models_model_id")): + op.drop_index(op.f("ix_llm_models_model_id"), table_name="llm_models") + if _has_index(inspector, "llm_models", op.f("ix_llm_models_gateway_id")): + op.drop_index(op.f("ix_llm_models_gateway_id"), table_name="llm_models") + op.drop_table("llm_models") + + inspector = sa.inspect(bind) + if inspector.has_table("llm_provider_auth"): + if _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_provider")): + op.drop_index(op.f("ix_llm_provider_auth_provider"), table_name="llm_provider_auth") + if _has_index( + inspector, + "llm_provider_auth", + op.f("ix_llm_provider_auth_organization_id"), + ): + op.drop_index( + op.f("ix_llm_provider_auth_organization_id"), + table_name="llm_provider_auth", + ) + if _has_index(inspector, "llm_provider_auth", op.f("ix_llm_provider_auth_gateway_id")): + op.drop_index( + op.f("ix_llm_provider_auth_gateway_id"), + table_name="llm_provider_auth", + ) + op.drop_table("llm_provider_auth") diff --git a/backend/tests/test_agent_model_assignment_updates.py b/backend/tests/test_agent_model_assignment_updates.py new file mode 100644 index 00000000..7d42c8a9 --- /dev/null +++ b/backend/tests/test_agent_model_assignment_updates.py @@ -0,0 +1,105 @@ +# ruff: noqa: S101 +"""Regression tests for agent model-assignment update normalization.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any +from uuid import UUID, uuid4 + +import pytest +from unittest.mock import AsyncMock + +from app.services.openclaw.provisioning_db import AgentLifecycleService + + +class _NoAutoflush: + def __init__(self, session: "_SessionStub") -> None: + self._session = session + + def __enter__(self) -> None: + self._session.in_no_autoflush = True + return None + + def __exit__(self, exc_type, exc, tb) -> bool: + self._session.in_no_autoflush = False + return False + + +class _SessionStub: + def __init__(self, valid_ids: set[UUID]) -> None: + self.valid_ids = valid_ids + self.in_no_autoflush = False + self.commits = 0 + + @property + def no_autoflush(self) -> _NoAutoflush: + return _NoAutoflush(self) + + async def exec(self, _statement: Any) -> list[UUID]: + if not self.in_no_autoflush: + raise AssertionError("Expected normalize query to run under no_autoflush.") + return list(self.valid_ids) + + def add(self, _model: Any) -> None: + return None + + async def commit(self) -> None: + self.commits += 1 + + async def refresh(self, _model: Any) -> None: + return None + + +@dataclass +class _AgentStub: + gateway_id: UUID + board_id: UUID | None = None + is_board_lead: bool = False + openclaw_session_id: str | None = None + primary_model_id: UUID | None = None + fallback_model_ids: list[str] | None = None + updated_at: datetime | None = None + heartbeat_config: dict[str, Any] | None = field( + default_factory=lambda: {"interval_seconds": 5}, + ) + + +@pytest.mark.asyncio +async def test_normalize_agent_model_assignments_uses_no_autoflush() -> None: + primary = uuid4() + fallback = uuid4() + session = _SessionStub({primary, fallback}) + service = AgentLifecycleService(session) # type: ignore[arg-type] + + normalized_primary, normalized_fallback = await service.normalize_agent_model_assignments( + gateway_id=uuid4(), + primary_model_id=primary, + fallback_model_ids=[primary, fallback], + ) + + assert normalized_primary == primary + assert normalized_fallback == [str(fallback)] + + +@pytest.mark.asyncio +async def test_apply_agent_update_mutations_coerces_fallback_ids_to_strings(monkeypatch) -> None: + primary = uuid4() + fallback = uuid4() + session = _SessionStub({primary, fallback}) + service = AgentLifecycleService(session) # type: ignore[arg-type] + monkeypatch.setattr(service, "get_main_agent_gateway", AsyncMock(return_value=None)) + + agent = _AgentStub(gateway_id=uuid4()) + updates: dict[str, Any] = { + "primary_model_id": primary, + "fallback_model_ids": [primary, fallback, fallback], + } + + await service.apply_agent_update_mutations(agent=agent, updates=updates, make_main=None) # type: ignore[arg-type] + + assert updates["fallback_model_ids"] == [str(primary), str(fallback)] + assert agent.primary_model_id == primary + assert agent.fallback_model_ids == [str(fallback)] + assert session.commits == 1 diff --git a/backend/tests/test_agent_schema_models.py b/backend/tests/test_agent_schema_models.py new file mode 100644 index 00000000..3856b167 --- /dev/null +++ b/backend/tests/test_agent_schema_models.py @@ -0,0 +1,28 @@ +# ruff: noqa: S101 +"""Tests for agent model-assignment schema normalization.""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from app.schemas.agents import AgentCreate + + +def test_agent_create_normalizes_fallback_model_ids() -> None: + model_a = uuid4() + model_b = uuid4() + + payload = AgentCreate( + name="Worker", + fallback_model_ids=[str(model_a), str(model_b), str(model_a)], + ) + + assert payload.fallback_model_ids == [model_a, model_b] + + +def test_agent_create_rejects_non_list_fallback_model_ids() -> None: + with pytest.raises(ValidationError): + AgentCreate(name="Worker", fallback_model_ids="not-a-list") diff --git a/backend/tests/test_model_registry_pull_helpers.py b/backend/tests/test_model_registry_pull_helpers.py new file mode 100644 index 00000000..b019509d --- /dev/null +++ b/backend/tests/test_model_registry_pull_helpers.py @@ -0,0 +1,110 @@ +# ruff: noqa: S101 +"""Tests for gateway model-registry pull helper normalization.""" + +from __future__ import annotations + +from app.services.openclaw.model_registry_service import ( + _extract_config_data, + _get_nested_path, + _infer_provider_for_model, + _model_config, + _model_settings, + _normalize_provider, + _parse_agent_model_value, +) + + +def test_get_nested_path_resolves_existing_value() -> None: + source = {"providers": {"openai": {"apiKey": "sk-test"}}} + + assert _get_nested_path(source, ["providers", "openai", "apiKey"]) == "sk-test" + assert _get_nested_path(source, ["providers", "anthropic", "apiKey"]) is None + + +def test_normalize_provider_trims_and_lowercases() -> None: + assert _normalize_provider(" OpenAI ") == "openai" + assert _normalize_provider("") is None + assert _normalize_provider(123) is None + + +def test_infer_provider_for_model_prefers_prefix_delimiter() -> None: + assert _infer_provider_for_model("openai/gpt-5") == "openai" + assert _infer_provider_for_model("anthropic:claude-sonnet") == "anthropic" + assert _infer_provider_for_model("gpt-5") == "unknown" + + +def test_model_settings_only_accepts_dict_payloads() -> None: + settings = _model_settings({"provider": "openai", "temperature": 0.2}) + + assert settings == {"provider": "openai", "temperature": 0.2} + assert _model_settings("not-a-dict") is None + + +def test_parse_agent_model_value_normalizes_primary_and_fallbacks() -> None: + primary, fallback = _parse_agent_model_value( + { + "primary": " openai/gpt-5 ", + "fallbacks": [ + "openai/gpt-4.1", + "openai/gpt-5", + "openai/gpt-4.1", + " ", + 123, + ], + }, + ) + + assert primary == "openai/gpt-5" + assert fallback == ["openai/gpt-4.1"] + + +def test_parse_agent_model_value_accepts_legacy_fallback_key() -> None: + primary, fallback = _parse_agent_model_value( + { + "primary": "openai/gpt-5", + "fallback": ["openai/gpt-4.1", "openai/gpt-4.1"], + }, + ) + + assert primary == "openai/gpt-5" + assert fallback == ["openai/gpt-4.1"] + + +def test_parse_agent_model_value_accepts_string_primary() -> None: + primary, fallback = _parse_agent_model_value(" openai/gpt-5 ") + + assert primary == "openai/gpt-5" + assert fallback == [] + + +def test_model_config_uses_fallbacks_key() -> None: + assert _model_config("openai/gpt-5", ["openai/gpt-4.1"]) == { + "primary": "openai/gpt-5", + "fallbacks": ["openai/gpt-4.1"], + } + + +def test_extract_config_data_prefers_parsed_when_config_is_raw_string() -> None: + config_data, base_hash = _extract_config_data( + { + "config": '{"agents":{"list":[{"id":"a1"}]}}', + "parsed": {"agents": {"list": [{"id": "a1"}]}}, + "hash": "abc123", + }, + ) + + assert isinstance(config_data, dict) + assert config_data.get("agents") == {"list": [{"id": "a1"}]} + assert base_hash == "abc123" + + +def test_extract_config_data_parses_json_string_when_parsed_absent() -> None: + config_data, base_hash = _extract_config_data( + { + "config": '{"providers":{"openai":{"apiKey":"sk-test"}}}', + "hash": "def456", + }, + ) + + assert config_data.get("providers") == {"openai": {"apiKey": "sk-test"}} + assert base_hash == "def456" diff --git a/frontend/src/api/generated/model-registry/model-registry.ts b/frontend/src/api/generated/model-registry/model-registry.ts new file mode 100644 index 00000000..490f2a0f --- /dev/null +++ b/frontend/src/api/generated/model-registry/model-registry.ts @@ -0,0 +1,1498 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ +import { useMutation, useQuery } from "@tanstack/react-query"; +import type { + DataTag, + DefinedInitialDataOptions, + DefinedUseQueryResult, + MutationFunction, + QueryClient, + QueryFunction, + QueryKey, + UndefinedInitialDataOptions, + UseMutationOptions, + UseMutationResult, + UseQueryOptions, + UseQueryResult, +} from "@tanstack/react-query"; + +import type { + GatewayModelSyncResult, + HTTPValidationError, + ListModelsApiV1ModelRegistryModelsGetParams, + ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, + LlmModelCreate, + LlmModelRead, + LlmModelUpdate, + LlmProviderAuthCreate, + LlmProviderAuthRead, + LlmProviderAuthUpdate, + OkResponse, +} from ".././model"; + +import { customFetch } from "../../mutator"; + +type SecondParameter unknown> = Parameters[1]; + +/** + * List provider auth records for the active organization. + * @summary List Provider Auth + */ +export type listProviderAuthApiV1ModelRegistryProviderAuthGetResponse200 = { + data: LlmProviderAuthRead[]; + status: 200; +}; + +export type listProviderAuthApiV1ModelRegistryProviderAuthGetResponse422 = { + data: HTTPValidationError; + status: 422; +}; + +export type listProviderAuthApiV1ModelRegistryProviderAuthGetResponseSuccess = + listProviderAuthApiV1ModelRegistryProviderAuthGetResponse200 & { + headers: Headers; + }; +export type listProviderAuthApiV1ModelRegistryProviderAuthGetResponseError = + listProviderAuthApiV1ModelRegistryProviderAuthGetResponse422 & { + headers: Headers; + }; + +export type listProviderAuthApiV1ModelRegistryProviderAuthGetResponse = + | listProviderAuthApiV1ModelRegistryProviderAuthGetResponseSuccess + | listProviderAuthApiV1ModelRegistryProviderAuthGetResponseError; + +export const getListProviderAuthApiV1ModelRegistryProviderAuthGetUrl = ( + params?: ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, +) => { + const normalizedParams = new URLSearchParams(); + + Object.entries(params || {}).forEach(([key, value]) => { + if (value !== undefined) { + normalizedParams.append(key, value === null ? "null" : value.toString()); + } + }); + + const stringifiedParams = normalizedParams.toString(); + + return stringifiedParams.length > 0 + ? `/api/v1/model-registry/provider-auth?${stringifiedParams}` + : `/api/v1/model-registry/provider-auth`; +}; + +export const listProviderAuthApiV1ModelRegistryProviderAuthGet = async ( + params?: ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, + options?: RequestInit, +): Promise => { + return customFetch( + getListProviderAuthApiV1ModelRegistryProviderAuthGetUrl(params), + { + ...options, + method: "GET", + }, + ); +}; + +export const getListProviderAuthApiV1ModelRegistryProviderAuthGetQueryKey = ( + params?: ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, +) => { + return [ + `/api/v1/model-registry/provider-auth`, + ...(params ? [params] : []), + ] as const; +}; + +export const getListProviderAuthApiV1ModelRegistryProviderAuthGetQueryOptions = + < + TData = Awaited< + ReturnType + >, + TError = HTTPValidationError, + >( + params?: ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited< + ReturnType + >, + TError, + TData + > + >; + request?: SecondParameter; + }, + ) => { + const { query: queryOptions, request: requestOptions } = options ?? {}; + + const queryKey = + queryOptions?.queryKey ?? + getListProviderAuthApiV1ModelRegistryProviderAuthGetQueryKey(params); + + const queryFn: QueryFunction< + Awaited< + ReturnType + > + > = ({ signal }) => + listProviderAuthApiV1ModelRegistryProviderAuthGet(params, { + signal, + ...requestOptions, + }); + + return { queryKey, queryFn, ...queryOptions } as UseQueryOptions< + Awaited< + ReturnType + >, + TError, + TData + > & { queryKey: DataTag }; + }; + +export type ListProviderAuthApiV1ModelRegistryProviderAuthGetQueryResult = + NonNullable< + Awaited< + ReturnType + > + >; +export type ListProviderAuthApiV1ModelRegistryProviderAuthGetQueryError = + HTTPValidationError; + +export function useListProviderAuthApiV1ModelRegistryProviderAuthGet< + TData = Awaited< + ReturnType + >, + TError = HTTPValidationError, +>( + params: undefined | ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, + options: { + query: Partial< + UseQueryOptions< + Awaited< + ReturnType + >, + TError, + TData + > + > & + Pick< + DefinedInitialDataOptions< + Awaited< + ReturnType + >, + TError, + Awaited< + ReturnType + > + >, + "initialData" + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): DefinedUseQueryResult & { + queryKey: DataTag; +}; +export function useListProviderAuthApiV1ModelRegistryProviderAuthGet< + TData = Awaited< + ReturnType + >, + TError = HTTPValidationError, +>( + params?: ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited< + ReturnType + >, + TError, + TData + > + > & + Pick< + UndefinedInitialDataOptions< + Awaited< + ReturnType + >, + TError, + Awaited< + ReturnType + > + >, + "initialData" + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseQueryResult & { + queryKey: DataTag; +}; +export function useListProviderAuthApiV1ModelRegistryProviderAuthGet< + TData = Awaited< + ReturnType + >, + TError = HTTPValidationError, +>( + params?: ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited< + ReturnType + >, + TError, + TData + > + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseQueryResult & { + queryKey: DataTag; +}; +/** + * @summary List Provider Auth + */ + +export function useListProviderAuthApiV1ModelRegistryProviderAuthGet< + TData = Awaited< + ReturnType + >, + TError = HTTPValidationError, +>( + params?: ListProviderAuthApiV1ModelRegistryProviderAuthGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited< + ReturnType + >, + TError, + TData + > + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseQueryResult & { + queryKey: DataTag; +} { + const queryOptions = + getListProviderAuthApiV1ModelRegistryProviderAuthGetQueryOptions( + params, + options, + ); + + const query = useQuery(queryOptions, queryClient) as UseQueryResult< + TData, + TError + > & { queryKey: DataTag }; + + return { ...query, queryKey: queryOptions.queryKey }; +} + +/** + * Create a provider auth record and sync gateway config. + * @summary Create Provider Auth + */ +export type createProviderAuthApiV1ModelRegistryProviderAuthPostResponse200 = { + data: LlmProviderAuthRead; + status: 200; +}; + +export type createProviderAuthApiV1ModelRegistryProviderAuthPostResponse422 = { + data: HTTPValidationError; + status: 422; +}; + +export type createProviderAuthApiV1ModelRegistryProviderAuthPostResponseSuccess = + createProviderAuthApiV1ModelRegistryProviderAuthPostResponse200 & { + headers: Headers; + }; +export type createProviderAuthApiV1ModelRegistryProviderAuthPostResponseError = + createProviderAuthApiV1ModelRegistryProviderAuthPostResponse422 & { + headers: Headers; + }; + +export type createProviderAuthApiV1ModelRegistryProviderAuthPostResponse = + | createProviderAuthApiV1ModelRegistryProviderAuthPostResponseSuccess + | createProviderAuthApiV1ModelRegistryProviderAuthPostResponseError; + +export const getCreateProviderAuthApiV1ModelRegistryProviderAuthPostUrl = + () => { + return `/api/v1/model-registry/provider-auth`; + }; + +export const createProviderAuthApiV1ModelRegistryProviderAuthPost = async ( + llmProviderAuthCreate: LlmProviderAuthCreate, + options?: RequestInit, +): Promise => { + return customFetch( + getCreateProviderAuthApiV1ModelRegistryProviderAuthPostUrl(), + { + ...options, + method: "POST", + headers: { "Content-Type": "application/json", ...options?.headers }, + body: JSON.stringify(llmProviderAuthCreate), + }, + ); +}; + +export const getCreateProviderAuthApiV1ModelRegistryProviderAuthPostMutationOptions = + (options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { data: LlmProviderAuthCreate }, + TContext + >; + request?: SecondParameter; + }): UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { data: LlmProviderAuthCreate }, + TContext + > => { + const mutationKey = [ + "createProviderAuthApiV1ModelRegistryProviderAuthPost", + ]; + const { mutation: mutationOptions, request: requestOptions } = options + ? options.mutation && + "mutationKey" in options.mutation && + options.mutation.mutationKey + ? options + : { ...options, mutation: { ...options.mutation, mutationKey } } + : { mutation: { mutationKey }, request: undefined }; + + const mutationFn: MutationFunction< + Awaited< + ReturnType + >, + { data: LlmProviderAuthCreate } + > = (props) => { + const { data } = props ?? {}; + + return createProviderAuthApiV1ModelRegistryProviderAuthPost( + data, + requestOptions, + ); + }; + + return { mutationFn, ...mutationOptions }; + }; + +export type CreateProviderAuthApiV1ModelRegistryProviderAuthPostMutationResult = + NonNullable< + Awaited< + ReturnType + > + >; +export type CreateProviderAuthApiV1ModelRegistryProviderAuthPostMutationBody = + LlmProviderAuthCreate; +export type CreateProviderAuthApiV1ModelRegistryProviderAuthPostMutationError = + HTTPValidationError; + +/** + * @summary Create Provider Auth + */ +export const useCreateProviderAuthApiV1ModelRegistryProviderAuthPost = < + TError = HTTPValidationError, + TContext = unknown, +>( + options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { data: LlmProviderAuthCreate }, + TContext + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseMutationResult< + Awaited< + ReturnType + >, + TError, + { data: LlmProviderAuthCreate }, + TContext +> => { + return useMutation( + getCreateProviderAuthApiV1ModelRegistryProviderAuthPostMutationOptions( + options, + ), + queryClient, + ); +}; +/** + * Patch a provider auth record and sync gateway config. + * @summary Update Provider Auth + */ +export type updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponse200 = + { + data: LlmProviderAuthRead; + status: 200; + }; + +export type updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponse422 = + { + data: HTTPValidationError; + status: 422; + }; + +export type updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponseSuccess = + updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponse200 & { + headers: Headers; + }; +export type updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponseError = + updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponse422 & { + headers: Headers; + }; + +export type updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponse = + + | updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponseSuccess + | updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchResponseError; + +export const getUpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchUrl = + (providerAuthId: string) => { + return `/api/v1/model-registry/provider-auth/${providerAuthId}`; + }; + +export const updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch = + async ( + providerAuthId: string, + llmProviderAuthUpdate: LlmProviderAuthUpdate, + options?: RequestInit, + ): Promise => { + return customFetch( + getUpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchUrl( + providerAuthId, + ), + { + ...options, + method: "PATCH", + headers: { "Content-Type": "application/json", ...options?.headers }, + body: JSON.stringify(llmProviderAuthUpdate), + }, + ); + }; + +export const getUpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchMutationOptions = + (options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType< + typeof updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch + > + >, + TError, + { providerAuthId: string; data: LlmProviderAuthUpdate }, + TContext + >; + request?: SecondParameter; + }): UseMutationOptions< + Awaited< + ReturnType< + typeof updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch + > + >, + TError, + { providerAuthId: string; data: LlmProviderAuthUpdate }, + TContext + > => { + const mutationKey = [ + "updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch", + ]; + const { mutation: mutationOptions, request: requestOptions } = options + ? options.mutation && + "mutationKey" in options.mutation && + options.mutation.mutationKey + ? options + : { ...options, mutation: { ...options.mutation, mutationKey } } + : { mutation: { mutationKey }, request: undefined }; + + const mutationFn: MutationFunction< + Awaited< + ReturnType< + typeof updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch + > + >, + { providerAuthId: string; data: LlmProviderAuthUpdate } + > = (props) => { + const { providerAuthId, data } = props ?? {}; + + return updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch( + providerAuthId, + data, + requestOptions, + ); + }; + + return { mutationFn, ...mutationOptions }; + }; + +export type UpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchMutationResult = + NonNullable< + Awaited< + ReturnType< + typeof updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch + > + > + >; +export type UpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchMutationBody = + LlmProviderAuthUpdate; +export type UpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchMutationError = + HTTPValidationError; + +/** + * @summary Update Provider Auth + */ +export const useUpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch = + ( + options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType< + typeof updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch + > + >, + TError, + { providerAuthId: string; data: LlmProviderAuthUpdate }, + TContext + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, + ): UseMutationResult< + Awaited< + ReturnType< + typeof updateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatch + > + >, + TError, + { providerAuthId: string; data: LlmProviderAuthUpdate }, + TContext + > => { + return useMutation( + getUpdateProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdPatchMutationOptions( + options, + ), + queryClient, + ); + }; +/** + * Delete a provider auth record and sync gateway config. + * @summary Delete Provider Auth + */ +export type deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponse200 = + { + data: OkResponse; + status: 200; + }; + +export type deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponse422 = + { + data: HTTPValidationError; + status: 422; + }; + +export type deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponseSuccess = + deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponse200 & { + headers: Headers; + }; +export type deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponseError = + deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponse422 & { + headers: Headers; + }; + +export type deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponse = + + | deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponseSuccess + | deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteResponseError; + +export const getDeleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteUrl = + (providerAuthId: string) => { + return `/api/v1/model-registry/provider-auth/${providerAuthId}`; + }; + +export const deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete = + async ( + providerAuthId: string, + options?: RequestInit, + ): Promise => { + return customFetch( + getDeleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteUrl( + providerAuthId, + ), + { + ...options, + method: "DELETE", + }, + ); + }; + +export const getDeleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteMutationOptions = + (options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType< + typeof deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete + > + >, + TError, + { providerAuthId: string }, + TContext + >; + request?: SecondParameter; + }): UseMutationOptions< + Awaited< + ReturnType< + typeof deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete + > + >, + TError, + { providerAuthId: string }, + TContext + > => { + const mutationKey = [ + "deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete", + ]; + const { mutation: mutationOptions, request: requestOptions } = options + ? options.mutation && + "mutationKey" in options.mutation && + options.mutation.mutationKey + ? options + : { ...options, mutation: { ...options.mutation, mutationKey } } + : { mutation: { mutationKey }, request: undefined }; + + const mutationFn: MutationFunction< + Awaited< + ReturnType< + typeof deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete + > + >, + { providerAuthId: string } + > = (props) => { + const { providerAuthId } = props ?? {}; + + return deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete( + providerAuthId, + requestOptions, + ); + }; + + return { mutationFn, ...mutationOptions }; + }; + +export type DeleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteMutationResult = + NonNullable< + Awaited< + ReturnType< + typeof deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete + > + > + >; + +export type DeleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteMutationError = + HTTPValidationError; + +/** + * @summary Delete Provider Auth + */ +export const useDeleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete = + ( + options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType< + typeof deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete + > + >, + TError, + { providerAuthId: string }, + TContext + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, + ): UseMutationResult< + Awaited< + ReturnType< + typeof deleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDelete + > + >, + TError, + { providerAuthId: string }, + TContext + > => { + return useMutation( + getDeleteProviderAuthApiV1ModelRegistryProviderAuthProviderAuthIdDeleteMutationOptions( + options, + ), + queryClient, + ); + }; +/** + * List gateway model catalog entries for the active organization. + * @summary List Models + */ +export type listModelsApiV1ModelRegistryModelsGetResponse200 = { + data: LlmModelRead[]; + status: 200; +}; + +export type listModelsApiV1ModelRegistryModelsGetResponse422 = { + data: HTTPValidationError; + status: 422; +}; + +export type listModelsApiV1ModelRegistryModelsGetResponseSuccess = + listModelsApiV1ModelRegistryModelsGetResponse200 & { + headers: Headers; + }; +export type listModelsApiV1ModelRegistryModelsGetResponseError = + listModelsApiV1ModelRegistryModelsGetResponse422 & { + headers: Headers; + }; + +export type listModelsApiV1ModelRegistryModelsGetResponse = + | listModelsApiV1ModelRegistryModelsGetResponseSuccess + | listModelsApiV1ModelRegistryModelsGetResponseError; + +export const getListModelsApiV1ModelRegistryModelsGetUrl = ( + params?: ListModelsApiV1ModelRegistryModelsGetParams, +) => { + const normalizedParams = new URLSearchParams(); + + Object.entries(params || {}).forEach(([key, value]) => { + if (value !== undefined) { + normalizedParams.append(key, value === null ? "null" : value.toString()); + } + }); + + const stringifiedParams = normalizedParams.toString(); + + return stringifiedParams.length > 0 + ? `/api/v1/model-registry/models?${stringifiedParams}` + : `/api/v1/model-registry/models`; +}; + +export const listModelsApiV1ModelRegistryModelsGet = async ( + params?: ListModelsApiV1ModelRegistryModelsGetParams, + options?: RequestInit, +): Promise => { + return customFetch( + getListModelsApiV1ModelRegistryModelsGetUrl(params), + { + ...options, + method: "GET", + }, + ); +}; + +export const getListModelsApiV1ModelRegistryModelsGetQueryKey = ( + params?: ListModelsApiV1ModelRegistryModelsGetParams, +) => { + return [ + `/api/v1/model-registry/models`, + ...(params ? [params] : []), + ] as const; +}; + +export const getListModelsApiV1ModelRegistryModelsGetQueryOptions = < + TData = Awaited>, + TError = HTTPValidationError, +>( + params?: ListModelsApiV1ModelRegistryModelsGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited>, + TError, + TData + > + >; + request?: SecondParameter; + }, +) => { + const { query: queryOptions, request: requestOptions } = options ?? {}; + + const queryKey = + queryOptions?.queryKey ?? + getListModelsApiV1ModelRegistryModelsGetQueryKey(params); + + const queryFn: QueryFunction< + Awaited> + > = ({ signal }) => + listModelsApiV1ModelRegistryModelsGet(params, { + signal, + ...requestOptions, + }); + + return { queryKey, queryFn, ...queryOptions } as UseQueryOptions< + Awaited>, + TError, + TData + > & { queryKey: DataTag }; +}; + +export type ListModelsApiV1ModelRegistryModelsGetQueryResult = NonNullable< + Awaited> +>; +export type ListModelsApiV1ModelRegistryModelsGetQueryError = + HTTPValidationError; + +export function useListModelsApiV1ModelRegistryModelsGet< + TData = Awaited>, + TError = HTTPValidationError, +>( + params: undefined | ListModelsApiV1ModelRegistryModelsGetParams, + options: { + query: Partial< + UseQueryOptions< + Awaited>, + TError, + TData + > + > & + Pick< + DefinedInitialDataOptions< + Awaited>, + TError, + Awaited> + >, + "initialData" + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): DefinedUseQueryResult & { + queryKey: DataTag; +}; +export function useListModelsApiV1ModelRegistryModelsGet< + TData = Awaited>, + TError = HTTPValidationError, +>( + params?: ListModelsApiV1ModelRegistryModelsGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited>, + TError, + TData + > + > & + Pick< + UndefinedInitialDataOptions< + Awaited>, + TError, + Awaited> + >, + "initialData" + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseQueryResult & { + queryKey: DataTag; +}; +export function useListModelsApiV1ModelRegistryModelsGet< + TData = Awaited>, + TError = HTTPValidationError, +>( + params?: ListModelsApiV1ModelRegistryModelsGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited>, + TError, + TData + > + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseQueryResult & { + queryKey: DataTag; +}; +/** + * @summary List Models + */ + +export function useListModelsApiV1ModelRegistryModelsGet< + TData = Awaited>, + TError = HTTPValidationError, +>( + params?: ListModelsApiV1ModelRegistryModelsGetParams, + options?: { + query?: Partial< + UseQueryOptions< + Awaited>, + TError, + TData + > + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseQueryResult & { + queryKey: DataTag; +} { + const queryOptions = getListModelsApiV1ModelRegistryModelsGetQueryOptions( + params, + options, + ); + + const query = useQuery(queryOptions, queryClient) as UseQueryResult< + TData, + TError + > & { queryKey: DataTag }; + + return { ...query, queryKey: queryOptions.queryKey }; +} + +/** + * Create a model catalog entry and sync gateway config. + * @summary Create Model + */ +export type createModelApiV1ModelRegistryModelsPostResponse200 = { + data: LlmModelRead; + status: 200; +}; + +export type createModelApiV1ModelRegistryModelsPostResponse422 = { + data: HTTPValidationError; + status: 422; +}; + +export type createModelApiV1ModelRegistryModelsPostResponseSuccess = + createModelApiV1ModelRegistryModelsPostResponse200 & { + headers: Headers; + }; +export type createModelApiV1ModelRegistryModelsPostResponseError = + createModelApiV1ModelRegistryModelsPostResponse422 & { + headers: Headers; + }; + +export type createModelApiV1ModelRegistryModelsPostResponse = + | createModelApiV1ModelRegistryModelsPostResponseSuccess + | createModelApiV1ModelRegistryModelsPostResponseError; + +export const getCreateModelApiV1ModelRegistryModelsPostUrl = () => { + return `/api/v1/model-registry/models`; +}; + +export const createModelApiV1ModelRegistryModelsPost = async ( + llmModelCreate: LlmModelCreate, + options?: RequestInit, +): Promise => { + return customFetch( + getCreateModelApiV1ModelRegistryModelsPostUrl(), + { + ...options, + method: "POST", + headers: { "Content-Type": "application/json", ...options?.headers }, + body: JSON.stringify(llmModelCreate), + }, + ); +}; + +export const getCreateModelApiV1ModelRegistryModelsPostMutationOptions = < + TError = HTTPValidationError, + TContext = unknown, +>(options?: { + mutation?: UseMutationOptions< + Awaited>, + TError, + { data: LlmModelCreate }, + TContext + >; + request?: SecondParameter; +}): UseMutationOptions< + Awaited>, + TError, + { data: LlmModelCreate }, + TContext +> => { + const mutationKey = ["createModelApiV1ModelRegistryModelsPost"]; + const { mutation: mutationOptions, request: requestOptions } = options + ? options.mutation && + "mutationKey" in options.mutation && + options.mutation.mutationKey + ? options + : { ...options, mutation: { ...options.mutation, mutationKey } } + : { mutation: { mutationKey }, request: undefined }; + + const mutationFn: MutationFunction< + Awaited>, + { data: LlmModelCreate } + > = (props) => { + const { data } = props ?? {}; + + return createModelApiV1ModelRegistryModelsPost(data, requestOptions); + }; + + return { mutationFn, ...mutationOptions }; +}; + +export type CreateModelApiV1ModelRegistryModelsPostMutationResult = NonNullable< + Awaited> +>; +export type CreateModelApiV1ModelRegistryModelsPostMutationBody = + LlmModelCreate; +export type CreateModelApiV1ModelRegistryModelsPostMutationError = + HTTPValidationError; + +/** + * @summary Create Model + */ +export const useCreateModelApiV1ModelRegistryModelsPost = < + TError = HTTPValidationError, + TContext = unknown, +>( + options?: { + mutation?: UseMutationOptions< + Awaited>, + TError, + { data: LlmModelCreate }, + TContext + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseMutationResult< + Awaited>, + TError, + { data: LlmModelCreate }, + TContext +> => { + return useMutation( + getCreateModelApiV1ModelRegistryModelsPostMutationOptions(options), + queryClient, + ); +}; +/** + * Patch a model catalog entry and sync gateway config. + * @summary Update Model + */ +export type updateModelApiV1ModelRegistryModelsModelIdPatchResponse200 = { + data: LlmModelRead; + status: 200; +}; + +export type updateModelApiV1ModelRegistryModelsModelIdPatchResponse422 = { + data: HTTPValidationError; + status: 422; +}; + +export type updateModelApiV1ModelRegistryModelsModelIdPatchResponseSuccess = + updateModelApiV1ModelRegistryModelsModelIdPatchResponse200 & { + headers: Headers; + }; +export type updateModelApiV1ModelRegistryModelsModelIdPatchResponseError = + updateModelApiV1ModelRegistryModelsModelIdPatchResponse422 & { + headers: Headers; + }; + +export type updateModelApiV1ModelRegistryModelsModelIdPatchResponse = + | updateModelApiV1ModelRegistryModelsModelIdPatchResponseSuccess + | updateModelApiV1ModelRegistryModelsModelIdPatchResponseError; + +export const getUpdateModelApiV1ModelRegistryModelsModelIdPatchUrl = ( + modelId: string, +) => { + return `/api/v1/model-registry/models/${modelId}`; +}; + +export const updateModelApiV1ModelRegistryModelsModelIdPatch = async ( + modelId: string, + llmModelUpdate: LlmModelUpdate, + options?: RequestInit, +): Promise => { + return customFetch( + getUpdateModelApiV1ModelRegistryModelsModelIdPatchUrl(modelId), + { + ...options, + method: "PATCH", + headers: { "Content-Type": "application/json", ...options?.headers }, + body: JSON.stringify(llmModelUpdate), + }, + ); +}; + +export const getUpdateModelApiV1ModelRegistryModelsModelIdPatchMutationOptions = + (options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { modelId: string; data: LlmModelUpdate }, + TContext + >; + request?: SecondParameter; + }): UseMutationOptions< + Awaited>, + TError, + { modelId: string; data: LlmModelUpdate }, + TContext + > => { + const mutationKey = ["updateModelApiV1ModelRegistryModelsModelIdPatch"]; + const { mutation: mutationOptions, request: requestOptions } = options + ? options.mutation && + "mutationKey" in options.mutation && + options.mutation.mutationKey + ? options + : { ...options, mutation: { ...options.mutation, mutationKey } } + : { mutation: { mutationKey }, request: undefined }; + + const mutationFn: MutationFunction< + Awaited< + ReturnType + >, + { modelId: string; data: LlmModelUpdate } + > = (props) => { + const { modelId, data } = props ?? {}; + + return updateModelApiV1ModelRegistryModelsModelIdPatch( + modelId, + data, + requestOptions, + ); + }; + + return { mutationFn, ...mutationOptions }; + }; + +export type UpdateModelApiV1ModelRegistryModelsModelIdPatchMutationResult = + NonNullable< + Awaited> + >; +export type UpdateModelApiV1ModelRegistryModelsModelIdPatchMutationBody = + LlmModelUpdate; +export type UpdateModelApiV1ModelRegistryModelsModelIdPatchMutationError = + HTTPValidationError; + +/** + * @summary Update Model + */ +export const useUpdateModelApiV1ModelRegistryModelsModelIdPatch = < + TError = HTTPValidationError, + TContext = unknown, +>( + options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { modelId: string; data: LlmModelUpdate }, + TContext + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseMutationResult< + Awaited>, + TError, + { modelId: string; data: LlmModelUpdate }, + TContext +> => { + return useMutation( + getUpdateModelApiV1ModelRegistryModelsModelIdPatchMutationOptions(options), + queryClient, + ); +}; +/** + * Delete a model catalog entry and sync gateway config. + * @summary Delete Model + */ +export type deleteModelApiV1ModelRegistryModelsModelIdDeleteResponse200 = { + data: OkResponse; + status: 200; +}; + +export type deleteModelApiV1ModelRegistryModelsModelIdDeleteResponse422 = { + data: HTTPValidationError; + status: 422; +}; + +export type deleteModelApiV1ModelRegistryModelsModelIdDeleteResponseSuccess = + deleteModelApiV1ModelRegistryModelsModelIdDeleteResponse200 & { + headers: Headers; + }; +export type deleteModelApiV1ModelRegistryModelsModelIdDeleteResponseError = + deleteModelApiV1ModelRegistryModelsModelIdDeleteResponse422 & { + headers: Headers; + }; + +export type deleteModelApiV1ModelRegistryModelsModelIdDeleteResponse = + | deleteModelApiV1ModelRegistryModelsModelIdDeleteResponseSuccess + | deleteModelApiV1ModelRegistryModelsModelIdDeleteResponseError; + +export const getDeleteModelApiV1ModelRegistryModelsModelIdDeleteUrl = ( + modelId: string, +) => { + return `/api/v1/model-registry/models/${modelId}`; +}; + +export const deleteModelApiV1ModelRegistryModelsModelIdDelete = async ( + modelId: string, + options?: RequestInit, +): Promise => { + return customFetch( + getDeleteModelApiV1ModelRegistryModelsModelIdDeleteUrl(modelId), + { + ...options, + method: "DELETE", + }, + ); +}; + +export const getDeleteModelApiV1ModelRegistryModelsModelIdDeleteMutationOptions = + (options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { modelId: string }, + TContext + >; + request?: SecondParameter; + }): UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { modelId: string }, + TContext + > => { + const mutationKey = ["deleteModelApiV1ModelRegistryModelsModelIdDelete"]; + const { mutation: mutationOptions, request: requestOptions } = options + ? options.mutation && + "mutationKey" in options.mutation && + options.mutation.mutationKey + ? options + : { ...options, mutation: { ...options.mutation, mutationKey } } + : { mutation: { mutationKey }, request: undefined }; + + const mutationFn: MutationFunction< + Awaited< + ReturnType + >, + { modelId: string } + > = (props) => { + const { modelId } = props ?? {}; + + return deleteModelApiV1ModelRegistryModelsModelIdDelete( + modelId, + requestOptions, + ); + }; + + return { mutationFn, ...mutationOptions }; + }; + +export type DeleteModelApiV1ModelRegistryModelsModelIdDeleteMutationResult = + NonNullable< + Awaited> + >; + +export type DeleteModelApiV1ModelRegistryModelsModelIdDeleteMutationError = + HTTPValidationError; + +/** + * @summary Delete Model + */ +export const useDeleteModelApiV1ModelRegistryModelsModelIdDelete = < + TError = HTTPValidationError, + TContext = unknown, +>( + options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType + >, + TError, + { modelId: string }, + TContext + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseMutationResult< + Awaited>, + TError, + { modelId: string }, + TContext +> => { + return useMutation( + getDeleteModelApiV1ModelRegistryModelsModelIdDeleteMutationOptions(options), + queryClient, + ); +}; +/** + * Push provider auth + model catalog + agent model links to a gateway. + * @summary Sync Gateway Models + */ +export type syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponse200 = + { + data: GatewayModelSyncResult; + status: 200; + }; + +export type syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponse422 = + { + data: HTTPValidationError; + status: 422; + }; + +export type syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponseSuccess = + syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponse200 & { + headers: Headers; + }; +export type syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponseError = + syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponse422 & { + headers: Headers; + }; + +export type syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponse = + + | syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponseSuccess + | syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostResponseError; + +export const getSyncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostUrl = + (gatewayId: string) => { + return `/api/v1/model-registry/gateways/${gatewayId}/sync`; + }; + +export const syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost = + async ( + gatewayId: string, + options?: RequestInit, + ): Promise => { + return customFetch( + getSyncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostUrl( + gatewayId, + ), + { + ...options, + method: "POST", + }, + ); + }; + +export const getSyncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostMutationOptions = + (options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType< + typeof syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost + > + >, + TError, + { gatewayId: string }, + TContext + >; + request?: SecondParameter; + }): UseMutationOptions< + Awaited< + ReturnType< + typeof syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost + > + >, + TError, + { gatewayId: string }, + TContext + > => { + const mutationKey = [ + "syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost", + ]; + const { mutation: mutationOptions, request: requestOptions } = options + ? options.mutation && + "mutationKey" in options.mutation && + options.mutation.mutationKey + ? options + : { ...options, mutation: { ...options.mutation, mutationKey } } + : { mutation: { mutationKey }, request: undefined }; + + const mutationFn: MutationFunction< + Awaited< + ReturnType< + typeof syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost + > + >, + { gatewayId: string } + > = (props) => { + const { gatewayId } = props ?? {}; + + return syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost( + gatewayId, + requestOptions, + ); + }; + + return { mutationFn, ...mutationOptions }; + }; + +export type SyncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostMutationResult = + NonNullable< + Awaited< + ReturnType< + typeof syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost + > + > + >; + +export type SyncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostMutationError = + HTTPValidationError; + +/** + * @summary Sync Gateway Models + */ +export const useSyncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost = < + TError = HTTPValidationError, + TContext = unknown, +>( + options?: { + mutation?: UseMutationOptions< + Awaited< + ReturnType< + typeof syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost + > + >, + TError, + { gatewayId: string }, + TContext + >; + request?: SecondParameter; + }, + queryClient?: QueryClient, +): UseMutationResult< + Awaited< + ReturnType< + typeof syncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPost + > + >, + TError, + { gatewayId: string }, + TContext +> => { + return useMutation( + getSyncGatewayModelsApiV1ModelRegistryGatewaysGatewayIdSyncPostMutationOptions( + options, + ), + queryClient, + ); +}; diff --git a/frontend/src/api/generated/model/agentCreate.ts b/frontend/src/api/generated/model/agentCreate.ts index 17a9e768..3c9b1f35 100644 --- a/frontend/src/api/generated/model/agentCreate.ts +++ b/frontend/src/api/generated/model/agentCreate.ts @@ -16,6 +16,8 @@ export interface AgentCreate { name: string; status?: string; heartbeat_config?: AgentCreateHeartbeatConfig; + primary_model_id?: string | null; + fallback_model_ids?: string[] | null; identity_profile?: AgentCreateIdentityProfile; identity_template?: string | null; soul_template?: string | null; diff --git a/frontend/src/api/generated/model/agentRead.ts b/frontend/src/api/generated/model/agentRead.ts index e7da99ff..2f3c4d5f 100644 --- a/frontend/src/api/generated/model/agentRead.ts +++ b/frontend/src/api/generated/model/agentRead.ts @@ -16,6 +16,8 @@ export interface AgentRead { name: string; status?: string; heartbeat_config?: AgentReadHeartbeatConfig; + primary_model_id?: string | null; + fallback_model_ids?: string[] | null; identity_profile?: AgentReadIdentityProfile; identity_template?: string | null; soul_template?: string | null; diff --git a/frontend/src/api/generated/model/agentUpdate.ts b/frontend/src/api/generated/model/agentUpdate.ts index c9a6a1f1..2269e9dc 100644 --- a/frontend/src/api/generated/model/agentUpdate.ts +++ b/frontend/src/api/generated/model/agentUpdate.ts @@ -16,6 +16,8 @@ export interface AgentUpdate { name?: string | null; status?: string | null; heartbeat_config?: AgentUpdateHeartbeatConfig; + primary_model_id?: string | null; + fallback_model_ids?: string[] | null; identity_profile?: AgentUpdateIdentityProfile; identity_template?: string | null; soul_template?: string | null; diff --git a/frontend/src/api/generated/model/approvalCreate.ts b/frontend/src/api/generated/model/approvalCreate.ts index 6f8c9afc..b5d3937e 100644 --- a/frontend/src/api/generated/model/approvalCreate.ts +++ b/frontend/src/api/generated/model/approvalCreate.ts @@ -14,6 +14,7 @@ import type { ApprovalCreateStatus } from "./approvalCreateStatus"; export interface ApprovalCreate { action_type: string; task_id?: string | null; + task_ids?: string[]; payload?: ApprovalCreatePayload; confidence: number; rubric_scores?: ApprovalCreateRubricScores; diff --git a/frontend/src/api/generated/model/approvalRead.ts b/frontend/src/api/generated/model/approvalRead.ts index 8740d689..9f99100b 100644 --- a/frontend/src/api/generated/model/approvalRead.ts +++ b/frontend/src/api/generated/model/approvalRead.ts @@ -14,6 +14,7 @@ import type { ApprovalReadStatus } from "./approvalReadStatus"; export interface ApprovalRead { action_type: string; task_id?: string | null; + task_ids?: string[]; payload?: ApprovalReadPayload; confidence: number; rubric_scores?: ApprovalReadRubricScores; diff --git a/frontend/src/api/generated/model/gatewayModelSyncResult.ts b/frontend/src/api/generated/model/gatewayModelSyncResult.ts new file mode 100644 index 00000000..9aefc2d2 --- /dev/null +++ b/frontend/src/api/generated/model/gatewayModelSyncResult.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +/** + * Summary of model/provider config sync operations for a gateway. + */ +export interface GatewayModelSyncResult { + gateway_id: string; + provider_auth_patched: number; + model_catalog_patched: number; + agent_models_patched: number; + sessions_patched: number; + errors?: string[]; +} diff --git a/frontend/src/api/generated/model/index.ts b/frontend/src/api/generated/model/index.ts index 0261ddaa..4aafd67d 100644 --- a/frontend/src/api/generated/model/index.ts +++ b/frontend/src/api/generated/model/index.ts @@ -101,6 +101,7 @@ export * from "./gatewayLeadMessageRequestKind"; export * from "./gatewayLeadMessageResponse"; export * from "./gatewayMainAskUserRequest"; export * from "./gatewayMainAskUserResponse"; +export * from "./gatewayModelSyncResult"; export * from "./gatewayRead"; export * from "./gatewayResolveQuery"; export * from "./gatewaySessionHistoryResponse"; @@ -152,8 +153,10 @@ export * from "./listBoardsApiV1AgentBoardsGetParams"; export * from "./listBoardsApiV1BoardsGetParams"; export * from "./listGatewaysApiV1GatewaysGetParams"; export * from "./listGatewaySessionsApiV1GatewaysSessionsGetParams"; +export * from "./listModelsApiV1ModelRegistryModelsGetParams"; export * from "./listOrgInvitesApiV1OrganizationsMeInvitesGetParams"; export * from "./listOrgMembersApiV1OrganizationsMeMembersGetParams"; +export * from "./listProviderAuthApiV1ModelRegistryProviderAuthGetParams"; export * from "./listSessionsApiV1GatewaySessionsGet200"; export * from "./listSessionsApiV1GatewaySessionsGetParams"; export * from "./listTaskCommentFeedApiV1ActivityTaskCommentsGetParams"; @@ -161,6 +164,15 @@ export * from "./listTaskCommentsApiV1AgentBoardsBoardIdTasksTaskIdCommentsGetPa export * from "./listTaskCommentsApiV1BoardsBoardIdTasksTaskIdCommentsGetParams"; export * from "./listTasksApiV1AgentBoardsBoardIdTasksGetParams"; export * from "./listTasksApiV1BoardsBoardIdTasksGetParams"; +export * from "./llmModelCreate"; +export * from "./llmModelCreateSettings"; +export * from "./llmModelRead"; +export * from "./llmModelReadSettings"; +export * from "./llmModelUpdate"; +export * from "./llmModelUpdateSettings"; +export * from "./llmProviderAuthCreate"; +export * from "./llmProviderAuthRead"; +export * from "./llmProviderAuthUpdate"; export * from "./okResponse"; export * from "./organizationActiveUpdate"; export * from "./organizationBoardAccessRead"; diff --git a/frontend/src/api/generated/model/listModelsApiV1ModelRegistryModelsGetParams.ts b/frontend/src/api/generated/model/listModelsApiV1ModelRegistryModelsGetParams.ts new file mode 100644 index 00000000..7aa7a943 --- /dev/null +++ b/frontend/src/api/generated/model/listModelsApiV1ModelRegistryModelsGetParams.ts @@ -0,0 +1,10 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +export type ListModelsApiV1ModelRegistryModelsGetParams = { + gateway_id?: string | null; +}; diff --git a/frontend/src/api/generated/model/listProviderAuthApiV1ModelRegistryProviderAuthGetParams.ts b/frontend/src/api/generated/model/listProviderAuthApiV1ModelRegistryProviderAuthGetParams.ts new file mode 100644 index 00000000..180e97a8 --- /dev/null +++ b/frontend/src/api/generated/model/listProviderAuthApiV1ModelRegistryProviderAuthGetParams.ts @@ -0,0 +1,10 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +export type ListProviderAuthApiV1ModelRegistryProviderAuthGetParams = { + gateway_id?: string | null; +}; diff --git a/frontend/src/api/generated/model/llmModelCreate.ts b/frontend/src/api/generated/model/llmModelCreate.ts new file mode 100644 index 00000000..3b49554c --- /dev/null +++ b/frontend/src/api/generated/model/llmModelCreate.ts @@ -0,0 +1,21 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ +import type { LlmModelCreateSettings } from "./llmModelCreateSettings"; + +/** + * Payload used to create a model catalog entry. + */ +export interface LlmModelCreate { + gateway_id: string; + /** @minLength 1 */ + provider: string; + /** @minLength 1 */ + model_id: string; + /** @minLength 1 */ + display_name: string; + settings?: LlmModelCreateSettings; +} diff --git a/frontend/src/api/generated/model/llmModelCreateSettings.ts b/frontend/src/api/generated/model/llmModelCreateSettings.ts new file mode 100644 index 00000000..21ddb6b8 --- /dev/null +++ b/frontend/src/api/generated/model/llmModelCreateSettings.ts @@ -0,0 +1,8 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +export type LlmModelCreateSettings = { [key: string]: unknown } | null; diff --git a/frontend/src/api/generated/model/llmModelRead.ts b/frontend/src/api/generated/model/llmModelRead.ts new file mode 100644 index 00000000..90ef5210 --- /dev/null +++ b/frontend/src/api/generated/model/llmModelRead.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ +import type { LlmModelReadSettings } from "./llmModelReadSettings"; + +/** + * Public model catalog entry payload. + */ +export interface LlmModelRead { + id: string; + organization_id: string; + gateway_id: string; + provider: string; + model_id: string; + display_name: string; + settings?: LlmModelReadSettings; + created_at: string; + updated_at: string; +} diff --git a/frontend/src/api/generated/model/llmModelReadSettings.ts b/frontend/src/api/generated/model/llmModelReadSettings.ts new file mode 100644 index 00000000..ec8418aa --- /dev/null +++ b/frontend/src/api/generated/model/llmModelReadSettings.ts @@ -0,0 +1,8 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +export type LlmModelReadSettings = { [key: string]: unknown } | null; diff --git a/frontend/src/api/generated/model/llmModelUpdate.ts b/frontend/src/api/generated/model/llmModelUpdate.ts new file mode 100644 index 00000000..1a32ad4f --- /dev/null +++ b/frontend/src/api/generated/model/llmModelUpdate.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ +import type { LlmModelUpdateSettings } from "./llmModelUpdateSettings"; + +/** + * Payload used to patch an existing model catalog entry. + */ +export interface LlmModelUpdate { + provider?: string | null; + model_id?: string | null; + display_name?: string | null; + settings?: LlmModelUpdateSettings; +} diff --git a/frontend/src/api/generated/model/llmModelUpdateSettings.ts b/frontend/src/api/generated/model/llmModelUpdateSettings.ts new file mode 100644 index 00000000..e5082bfe --- /dev/null +++ b/frontend/src/api/generated/model/llmModelUpdateSettings.ts @@ -0,0 +1,8 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +export type LlmModelUpdateSettings = { [key: string]: unknown } | null; diff --git a/frontend/src/api/generated/model/llmProviderAuthCreate.ts b/frontend/src/api/generated/model/llmProviderAuthCreate.ts new file mode 100644 index 00000000..ebe96699 --- /dev/null +++ b/frontend/src/api/generated/model/llmProviderAuthCreate.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +/** + * Payload used to create a provider auth record. + */ +export interface LlmProviderAuthCreate { + gateway_id: string; + /** @minLength 1 */ + provider: string; + config_path?: string | null; + /** @minLength 1 */ + secret: string; +} diff --git a/frontend/src/api/generated/model/llmProviderAuthRead.ts b/frontend/src/api/generated/model/llmProviderAuthRead.ts new file mode 100644 index 00000000..05bc67d2 --- /dev/null +++ b/frontend/src/api/generated/model/llmProviderAuthRead.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +/** + * Public provider auth payload (secret value is never returned). + */ +export interface LlmProviderAuthRead { + id: string; + organization_id: string; + gateway_id: string; + provider: string; + config_path: string; + has_secret?: boolean; + created_at: string; + updated_at: string; +} diff --git a/frontend/src/api/generated/model/llmProviderAuthUpdate.ts b/frontend/src/api/generated/model/llmProviderAuthUpdate.ts new file mode 100644 index 00000000..0f94394c --- /dev/null +++ b/frontend/src/api/generated/model/llmProviderAuthUpdate.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.2.0 🍺 + * Do not edit manually. + * Mission Control API + * OpenAPI spec version: 0.1.0 + */ + +/** + * Payload used to patch an existing provider auth record. + */ +export interface LlmProviderAuthUpdate { + provider?: string | null; + config_path?: string | null; + secret?: string | null; +} diff --git a/frontend/src/app/agents/[agentId]/edit/page.tsx b/frontend/src/app/agents/[agentId]/edit/page.tsx index fee9303a..e5fd6072 100644 --- a/frontend/src/app/agents/[agentId]/edit/page.tsx +++ b/frontend/src/app/agents/[agentId]/edit/page.tsx @@ -17,6 +17,10 @@ import { type listBoardsApiV1BoardsGetResponse, useListBoardsApiV1BoardsGet, } from "@/api/generated/boards/boards"; +import { + type listModelsApiV1ModelRegistryModelsGetResponse, + useListModelsApiV1ModelRegistryModelsGet, +} from "@/api/generated/model-registry/model-registry"; import type { AgentRead, AgentUpdate, BoardRead } from "@/api/generated/model"; import { DashboardPageLayout } from "@/components/templates/DashboardPageLayout"; import { Button } from "@/components/ui/button"; @@ -118,6 +122,12 @@ export default function EditAgentPage() { const [heartbeatTarget, setHeartbeatTarget] = useState( undefined, ); + const [primaryModelId, setPrimaryModelId] = useState( + undefined, + ); + const [fallbackModelIds, setFallbackModelIds] = useState( + undefined, + ); const [identityProfile, setIdentityProfile] = useState< IdentityProfile | undefined >(undefined); @@ -136,6 +146,16 @@ export default function EditAgentPage() { retry: false, }, }); + const modelsQuery = useListModelsApiV1ModelRegistryModelsGet< + listModelsApiV1ModelRegistryModelsGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn), + refetchOnMount: "always", + retry: false, + }, + }); const agentQuery = useGetAgentApiV1AgentsAgentIdGet< getAgentApiV1AgentsAgentIdGetResponse, @@ -165,6 +185,10 @@ export default function EditAgentPage() { if (boardsQuery.data?.status !== 200) return []; return boardsQuery.data.data.items ?? []; }, [boardsQuery.data]); + const models = useMemo(() => { + if (modelsQuery.data?.status !== 200) return []; + return modelsQuery.data.data; + }, [modelsQuery.data]); const loadedAgent: AgentRead | null = agentQuery.data?.status === 200 ? agentQuery.data.data : null; @@ -201,17 +225,30 @@ export default function EditAgentPage() { const loadedSoulTemplate = useMemo(() => { return loadedAgent?.soul_template?.trim() || DEFAULT_SOUL_TEMPLATE; }, [loadedAgent?.soul_template]); + const loadedFallbackModelIds = useMemo( + () => loadedAgent?.fallback_model_ids ?? [], + [loadedAgent?.fallback_model_ids], + ); const isLoading = - boardsQuery.isLoading || agentQuery.isLoading || updateMutation.isPending; + boardsQuery.isLoading || + modelsQuery.isLoading || + agentQuery.isLoading || + updateMutation.isPending; const errorMessage = - error ?? agentQuery.error?.message ?? boardsQuery.error?.message ?? null; + error ?? + agentQuery.error?.message ?? + boardsQuery.error?.message ?? + modelsQuery.error?.message ?? + null; const resolvedName = name ?? loadedAgent?.name ?? ""; const resolvedIsGatewayMain = isGatewayMain ?? Boolean(loadedAgent?.is_gateway_main); const resolvedHeartbeatEvery = heartbeatEvery ?? loadedHeartbeat.every; const resolvedHeartbeatTarget = heartbeatTarget ?? loadedHeartbeat.target; + const resolvedPrimaryModelId = primaryModelId ?? loadedAgent?.primary_model_id ?? ""; + const resolvedFallbackModelIds = fallbackModelIds ?? loadedFallbackModelIds; const resolvedIdentityProfile = identityProfile ?? loadedIdentityProfile; const resolvedSoulTemplate = soulTemplate ?? loadedSoulTemplate; @@ -219,6 +256,40 @@ export default function EditAgentPage() { if (resolvedIsGatewayMain) return boardId ?? ""; return boardId ?? loadedAgent?.board_id ?? boards[0]?.id ?? ""; }, [boardId, boards, loadedAgent?.board_id, resolvedIsGatewayMain]); + const targetGatewayId = useMemo(() => { + if (!loadedAgent) return null; + if (resolvedBoardId) { + const selectedBoard = boards.find((board) => board.id === resolvedBoardId); + if (selectedBoard?.gateway_id) { + return selectedBoard.gateway_id; + } + } + return loadedAgent.gateway_id; + }, [boards, loadedAgent, resolvedBoardId]); + const availableModels = useMemo( + () => models.filter((model) => model.gateway_id === targetGatewayId), + [models, targetGatewayId], + ); + const modelOptions = useMemo( + () => + availableModels.map((model) => ({ + value: model.id, + label: `${model.display_name} (${model.model_id})`, + })), + [availableModels], + ); + const availableModelIds = useMemo( + () => new Set(availableModels.map((model) => model.id)), + [availableModels], + ); + const effectivePrimaryModelId = availableModelIds.has(resolvedPrimaryModelId) + ? resolvedPrimaryModelId + : ""; + const effectiveFallbackModelIds = resolvedFallbackModelIds.filter( + (modelIdValue) => + modelIdValue !== effectivePrimaryModelId && + availableModelIds.has(modelIdValue), + ); const handleSubmit = (event: React.FormEvent) => { event.preventDefault(); @@ -250,6 +321,11 @@ export default function EditAgentPage() { typeof loadedAgent.heartbeat_config === "object" ? (loadedAgent.heartbeat_config as Record) : {}; + const normalizedFallbackModelIds = effectiveFallbackModelIds.filter( + (modelIdValue) => + modelIdValue !== effectivePrimaryModelId && + availableModelIds.has(modelIdValue), + ); const payload: AgentUpdate = { name: trimmed, @@ -266,6 +342,14 @@ export default function EditAgentPage() { loadedAgent.identity_profile, resolvedIdentityProfile, ) as unknown as Record | null, + primary_model_id: + effectivePrimaryModelId && availableModelIds.has(effectivePrimaryModelId) + ? effectivePrimaryModelId + : null, + fallback_model_ids: + normalizedFallbackModelIds.length > 0 + ? normalizedFallbackModelIds + : null, soul_template: resolvedSoulTemplate.trim() || null, }; if (!resolvedIsGatewayMain) { @@ -469,6 +553,78 @@ export default function EditAgentPage() { +
+

+ LLM routing +

+
+
+ + + {modelOptions.length === 0 ? ( +

+ No models found for this agent's gateway. Configure models + in the Models page. +

+ ) : null} +
+
+ + {modelOptions.length === 0 ? ( +

No fallback models yet.

+ ) : ( +
+ {modelOptions.map((option) => ( + + ))} +
+ )} +
+
+
+

Schedule & notifications diff --git a/frontend/src/app/agents/new/page.tsx b/frontend/src/app/agents/new/page.tsx index 7f6af777..82df312f 100644 --- a/frontend/src/app/agents/new/page.tsx +++ b/frontend/src/app/agents/new/page.tsx @@ -2,7 +2,7 @@ export const dynamic = "force-dynamic"; -import { useState } from "react"; +import { useMemo, useState } from "react"; import { useRouter } from "next/navigation"; import { useAuth } from "@/auth/clerk"; @@ -12,9 +12,13 @@ import { type listBoardsApiV1BoardsGetResponse, useListBoardsApiV1BoardsGet, } from "@/api/generated/boards/boards"; +import { + type listModelsApiV1ModelRegistryModelsGetResponse, + useListModelsApiV1ModelRegistryModelsGet, +} from "@/api/generated/model-registry/model-registry"; import { useCreateAgentApiV1AgentsPost } from "@/api/generated/agents/agents"; import { useOrganizationMembership } from "@/lib/use-organization-membership"; -import type { BoardRead } from "@/api/generated/model"; +import type { BoardRead, LlmModelRead } from "@/api/generated/model"; import { DashboardPageLayout } from "@/components/templates/DashboardPageLayout"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -86,6 +90,8 @@ export default function NewAgentPage() { const [boardId, setBoardId] = useState(""); const [heartbeatEvery, setHeartbeatEvery] = useState("10m"); const [heartbeatTarget, setHeartbeatTarget] = useState("none"); + const [primaryModelId, setPrimaryModelId] = useState(""); + const [fallbackModelIds, setFallbackModelIds] = useState([]); const [identityProfile, setIdentityProfile] = useState({ ...DEFAULT_IDENTITY_PROFILE, }); @@ -114,12 +120,48 @@ export default function NewAgentPage() { }, }, }); + const modelsQuery = useListModelsApiV1ModelRegistryModelsGet< + listModelsApiV1ModelRegistryModelsGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn && isAdmin), + refetchOnMount: "always", + }, + }); const boards = boardsQuery.data?.status === 200 ? (boardsQuery.data.data.items ?? []) : []; + const models: LlmModelRead[] = + modelsQuery.data?.status === 200 ? modelsQuery.data.data : []; const displayBoardId = boardId || boards[0]?.id || ""; + const selectedBoard = boards.find((board) => board.id === displayBoardId) ?? null; + const availableModels = models.filter( + (model) => model.gateway_id === selectedBoard?.gateway_id, + ); + const modelOptions = availableModels.map((model) => ({ + value: model.id, + label: `${model.display_name} (${model.model_id})`, + })); + const availableModelIds = useMemo( + () => new Set(availableModels.map((model) => model.id)), + [availableModels], + ); + const effectivePrimaryModelId = availableModelIds.has(primaryModelId) + ? primaryModelId + : ""; + const effectiveFallbackModelIds = fallbackModelIds.filter( + (modelId) => + modelId !== effectivePrimaryModelId && availableModelIds.has(modelId), + ); const isLoading = boardsQuery.isLoading || createAgentMutation.isPending; - const errorMessage = error ?? boardsQuery.error?.message ?? null; + const errorMessage = + error ?? boardsQuery.error?.message ?? modelsQuery.error?.message ?? null; + + const normalizedFallbackModelIds = fallbackModelIds.filter( + (modelId) => + modelId !== effectivePrimaryModelId && availableModelIds.has(modelId), + ); const handleSubmit = (event: React.FormEvent) => { event.preventDefault(); @@ -147,6 +189,11 @@ export default function NewAgentPage() { identity_profile: normalizeIdentityProfile( identityProfile, ) as unknown as Record | null, + primary_model_id: effectivePrimaryModelId || null, + fallback_model_ids: + normalizedFallbackModelIds.length > 0 + ? normalizedFallbackModelIds + : null, soul_template: soulTemplate.trim() || null, }, }); @@ -290,6 +337,74 @@ export default function NewAgentPage() {

+
+

+ LLM routing +

+
+
+ + + {availableModels.length === 0 ? ( +

+ No models found for this board's gateway. Configure models + in the Models page first. +

+ ) : null} +
+
+ + {modelOptions.length === 0 ? ( +

No fallback models yet.

+ ) : ( +
+ {modelOptions.map((option) => ( + + ))} +
+ )} +
+
+
+

Schedule & notifications diff --git a/frontend/src/app/models/_components/AgentRoutingEditPage.tsx b/frontend/src/app/models/_components/AgentRoutingEditPage.tsx new file mode 100644 index 00000000..34ec1282 --- /dev/null +++ b/frontend/src/app/models/_components/AgentRoutingEditPage.tsx @@ -0,0 +1,493 @@ +"use client"; + +import { useMemo, useState } from "react"; +import Link from "next/link"; +import { useRouter, useSearchParams } from "next/navigation"; + +import { useAuth } from "@/auth/clerk"; +import { useQueryClient } from "@tanstack/react-query"; + +import { ApiError } from "@/api/mutator"; +import { + type listAgentsApiV1AgentsGetResponse, + getListAgentsApiV1AgentsGetQueryKey, + useListAgentsApiV1AgentsGet, + useUpdateAgentApiV1AgentsAgentIdPatch, +} from "@/api/generated/agents/agents"; +import { + type listBoardsApiV1BoardsGetResponse, + useListBoardsApiV1BoardsGet, +} from "@/api/generated/boards/boards"; +import { + type listGatewaysApiV1GatewaysGetResponse, + useListGatewaysApiV1GatewaysGet, +} from "@/api/generated/gateways/gateways"; +import type { AgentRead, BoardRead, LlmModelRead } from "@/api/generated/model"; +import { + type listModelsApiV1ModelRegistryModelsGetResponse, + useListModelsApiV1ModelRegistryModelsGet, +} from "@/api/generated/model-registry/model-registry"; +import { DashboardPageLayout } from "@/components/templates/DashboardPageLayout"; +import { Badge } from "@/components/ui/badge"; +import { Button, buttonVariants } from "@/components/ui/button"; +import SearchableSelect, { + type SearchableSelectOption, +} from "@/components/ui/searchable-select"; +import { useOrganizationMembership } from "@/lib/use-organization-membership"; + +type AgentRoutingEditPageProps = { + agentId: string; +}; + +type RoutingStatus = "override" | "default" | "unconfigured"; + +const routingStatusLabel = (status: RoutingStatus): string => { + if (status === "override") return "Primary override"; + if (status === "default") return "Using default"; + return "No primary"; +}; + +const routingStatusVariant = ( + status: RoutingStatus, +): "success" | "accent" | "warning" => { + if (status === "override") return "success"; + if (status === "default") return "accent"; + return "warning"; +}; + +const agentRoleLabel = (agent: AgentRead): string | null => { + const role = agent.identity_profile?.role; + if (typeof role !== "string") return null; + const normalized = role.trim(); + return normalized || null; +}; + +const modelOptionLabel = (model: LlmModelRead): string => + `${model.display_name} (${model.model_id})`; + +const stringListsMatch = (left: string[], right: string[]): boolean => { + if (left.length !== right.length) return false; + for (let index = 0; index < left.length; index += 1) { + if (left[index] !== right[index]) return false; + } + return true; +}; + +const withGatewayQuery = (path: string, gatewayId: string): string => { + if (!gatewayId) return path; + return `${path}?gateway=${encodeURIComponent(gatewayId)}`; +}; + +export default function AgentRoutingEditPage({ agentId }: AgentRoutingEditPageProps) { + const { isSignedIn } = useAuth(); + const router = useRouter(); + const queryClient = useQueryClient(); + const searchParams = useSearchParams(); + const { isAdmin } = useOrganizationMembership(isSignedIn); + + const [primaryModelDraft, setPrimaryModelDraft] = useState(null); + const [fallbackModelDraft, setFallbackModelDraft] = useState(null); + const [error, setError] = useState(null); + + const agentsKey = getListAgentsApiV1AgentsGetQueryKey(); + + const agentsQuery = useListAgentsApiV1AgentsGet< + listAgentsApiV1AgentsGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn && isAdmin), + refetchOnMount: "always", + }, + }); + + const boardsQuery = useListBoardsApiV1BoardsGet< + listBoardsApiV1BoardsGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn && isAdmin), + refetchOnMount: "always", + }, + }); + + const gatewaysQuery = useListGatewaysApiV1GatewaysGet< + listGatewaysApiV1GatewaysGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn && isAdmin), + refetchOnMount: "always", + }, + }); + + const modelsQuery = useListModelsApiV1ModelRegistryModelsGet< + listModelsApiV1ModelRegistryModelsGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn && isAdmin), + refetchOnMount: "always", + }, + }); + + const updateAgentMutation = useUpdateAgentApiV1AgentsAgentIdPatch({ + mutation: { + onSuccess: async () => { + await queryClient.invalidateQueries({ queryKey: agentsKey }); + if (agent?.gateway_id) { + router.push(withGatewayQuery("/models/routing", agent.gateway_id)); + return; + } + router.push("/models/routing"); + }, + onError: (updateError) => { + setError(updateError.message || "Unable to save agent routing."); + }, + }, + }); + + const agents = useMemo(() => { + if (agentsQuery.data?.status !== 200) return []; + return agentsQuery.data.data.items ?? []; + }, [agentsQuery.data]); + + const boards = useMemo(() => { + if (boardsQuery.data?.status !== 200) return []; + return boardsQuery.data.data.items ?? []; + }, [boardsQuery.data]); + + const gateways = useMemo(() => { + if (gatewaysQuery.data?.status !== 200) return []; + return gatewaysQuery.data.data.items ?? []; + }, [gatewaysQuery.data]); + + const models = useMemo(() => { + if (modelsQuery.data?.status !== 200) return []; + return modelsQuery.data.data; + }, [modelsQuery.data]); + + const boardsById = useMemo(() => new Map(boards.map((board) => [board.id, board] as const)), [boards]); + const gatewaysById = useMemo( + () => new Map(gateways.map((gateway) => [gateway.id, gateway] as const)), + [gateways], + ); + + const agent = agents.find((item) => item.id === agentId) ?? null; + const agentBoard = agent?.board_id ? (boardsById.get(agent.board_id) ?? null) : null; + const modelsForGateway = agent?.gateway_id + ? models.filter((model) => model.gateway_id === agent.gateway_id) + : []; + const modelsById = new Map(modelsForGateway.map((item) => [item.id, item] as const)); + const availableModelIds = new Set(modelsForGateway.map((model) => model.id)); + const defaultPrimaryModel = modelsForGateway[0] ?? null; + const baselinePrimaryModelId = agent?.primary_model_id ?? ""; + const baselineFallbackModelIds = agent?.fallback_model_ids ?? []; + const primaryModelIdCandidate = primaryModelDraft ?? baselinePrimaryModelId; + const primaryModelId = availableModelIds.has(primaryModelIdCandidate) + ? primaryModelIdCandidate + : ""; + const fallbackModelIds = (() => { + const source = fallbackModelDraft ?? baselineFallbackModelIds; + return source.filter( + (modelIdValue, index, list) => + modelIdValue !== primaryModelId && + availableModelIds.has(modelIdValue) && + list.indexOf(modelIdValue) === index, + ); + })(); + + const selectedPrimary = primaryModelId ? (modelsById.get(primaryModelId) ?? null) : null; + + const effectivePrimaryModel = selectedPrimary ?? defaultPrimaryModel; + + const status: RoutingStatus = primaryModelId + ? "override" + : effectivePrimaryModel + ? "default" + : "unconfigured"; + + const selectedFallbackModels = fallbackModelIds + .map((id) => modelsById.get(id) ?? null) + .filter((model): model is LlmModelRead => model !== null); + + const modelOptions: SearchableSelectOption[] = modelsForGateway.map((model) => ({ + value: model.id, + label: modelOptionLabel(model), + })); + + const hasUnsavedChanges = (() => { + if (!agent) return false; + return ( + baselinePrimaryModelId !== primaryModelId || + !stringListsMatch(baselineFallbackModelIds, fallbackModelIds) + ); + })(); + + const handleSave = () => { + if (!agent) { + setError("Agent not found."); + return; + } + + if (primaryModelId && !availableModelIds.has(primaryModelId)) { + setError("Primary model must belong to this gateway catalog."); + return; + } + + setError(null); + updateAgentMutation.mutate({ + agentId: agent.id, + params: { force: true }, + data: { + primary_model_id: primaryModelId || null, + fallback_model_ids: fallbackModelIds.length > 0 ? fallbackModelIds : null, + }, + }); + }; + + const handleRevert = () => { + setError(null); + setPrimaryModelDraft(null); + setFallbackModelDraft(null); + }; + + const requestedGateway = searchParams.get("gateway")?.trim() ?? ""; + const backGatewayId = agent?.gateway_id ?? requestedGateway; + const gatewayName = agent?.gateway_id ? (gatewaysById.get(agent.gateway_id)?.name ?? null) : null; + + const pageError = + agentsQuery.error?.message ?? + boardsQuery.error?.message ?? + gatewaysQuery.error?.message ?? + modelsQuery.error?.message ?? + null; + + const missingAgent = + !agentsQuery.isLoading && + agentsQuery.data?.status === 200 && + !agent; + + return ( + +

+ {missingAgent ? ( +
+ Agent not found. +
+ ) : ( +
+
+

Agent details

+ {!agent ? ( +

Loading agent...

+ ) : ( + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Agent + + {agent.name} + +
Role{agentRoleLabel(agent) ?? "Unspecified"}
Board + {agent.board_id ? ( + + {agentBoard?.name ?? "Open board"} + + ) : ( + "Gateway main" + )} +
Gateway + {gatewayName ? `${gatewayName} (${agent.gateway_id})` : (agent.gateway_id || "Unknown gateway")} +
Status + {routingStatusLabel(status)} +
Effective primary + {effectivePrimaryModel ? modelOptionLabel(effectivePrimaryModel) : "None"} + {!selectedPrimary && effectivePrimaryModel ? " (inherited from default)" : ""} +
+ )} +
+ +
+
+

Routing assignment

+ {!agent ? ( + Loading + ) : hasUnsavedChanges ? ( + Unsaved changes + ) : ( + Saved + )} +
+

+ Primary override is optional. Empty primary inherits the gateway default. +

+ +
+
+

Primary model override

+ { + setPrimaryModelDraft(value); + setFallbackModelDraft((current) => { + const source = current ?? baselineFallbackModelIds; + return source.filter((item) => item !== value); + }); + }} + options={modelOptions} + placeholder="Use gateway default (no override)" + searchPlaceholder="Search models..." + emptyMessage="No matching models." + triggerClassName="w-full" + disabled={!agent || updateAgentMutation.isPending || modelOptions.length === 0} + /> +
+ +
+
+

+ Fallback models ({fallbackModelIds.length}) +

+ {selectedFallbackModels.length > 0 ? ( +
+ {selectedFallbackModels.map((model) => ( + + {model.display_name} + + ))} +
+ ) : null} +
+ + {modelOptions.length === 0 ? ( +

No catalog models available for this gateway yet.

+ ) : ( +
+ {modelOptions.map((option) => { + const checked = fallbackModelIds.includes(option.value); + const disabled = option.value === primaryModelId || !agent; + const model = modelsById.get(option.value) ?? null; + return ( + + ); + })} +
+ )} +
+
+ +
+ + + + + Back to routing table + +
+
+
+ )} + + {error ?

{error}

: null} + {pageError ?

{pageError}

: null} +
+ + ); +} diff --git a/frontend/src/app/models/_components/CatalogModelFormPage.tsx b/frontend/src/app/models/_components/CatalogModelFormPage.tsx new file mode 100644 index 00000000..f452cccb --- /dev/null +++ b/frontend/src/app/models/_components/CatalogModelFormPage.tsx @@ -0,0 +1,344 @@ +"use client"; + +import { useMemo, useState, type FormEvent } from "react"; +import Link from "next/link"; +import { useRouter, useSearchParams } from "next/navigation"; + +import { useAuth } from "@/auth/clerk"; +import { useQueryClient } from "@tanstack/react-query"; + +import { ApiError } from "@/api/mutator"; +import { + type listGatewaysApiV1GatewaysGetResponse, + useListGatewaysApiV1GatewaysGet, +} from "@/api/generated/gateways/gateways"; +import type { GatewayRead, LlmModelRead } from "@/api/generated/model"; +import { + type listModelsApiV1ModelRegistryModelsGetResponse, + getListModelsApiV1ModelRegistryModelsGetQueryKey, + useCreateModelApiV1ModelRegistryModelsPost, + useListModelsApiV1ModelRegistryModelsGet, + useUpdateModelApiV1ModelRegistryModelsModelIdPatch, +} from "@/api/generated/model-registry/model-registry"; +import { DashboardPageLayout } from "@/components/templates/DashboardPageLayout"; +import { Button, buttonVariants } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import SearchableSelect, { + type SearchableSelectOption, +} from "@/components/ui/searchable-select"; +import { Textarea } from "@/components/ui/textarea"; +import { useOrganizationMembership } from "@/lib/use-organization-membership"; + +type CatalogModelFormPageProps = + | { mode: "create" } + | { mode: "edit"; modelId: string }; + +const toGatewayOptions = (gateways: GatewayRead[]): SearchableSelectOption[] => + gateways.map((gateway) => ({ value: gateway.id, label: gateway.name })); + +const withGatewayQuery = (path: string, gatewayId: string): string => { + if (!gatewayId) return path; + return `${path}?gateway=${encodeURIComponent(gatewayId)}`; +}; + +export default function CatalogModelFormPage(props: CatalogModelFormPageProps) { + const { isSignedIn } = useAuth(); + const router = useRouter(); + const searchParams = useSearchParams(); + const queryClient = useQueryClient(); + const { isAdmin } = useOrganizationMembership(isSignedIn); + + const modelIdParam = props.mode === "edit" ? props.modelId : null; + + const [gatewayDraft, setGatewayDraft] = useState(null); + const [providerDraft, setProviderDraft] = useState(null); + const [modelIdDraft, setModelIdDraft] = useState(null); + const [displayNameDraft, setDisplayNameDraft] = useState(null); + const [settingsDraft, setSettingsDraft] = useState(null); + const [error, setError] = useState(null); + + const modelsKey = getListModelsApiV1ModelRegistryModelsGetQueryKey(); + + const gatewaysQuery = useListGatewaysApiV1GatewaysGet< + listGatewaysApiV1GatewaysGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn && isAdmin), + refetchOnMount: "always", + }, + }); + + const modelsQuery = useListModelsApiV1ModelRegistryModelsGet< + listModelsApiV1ModelRegistryModelsGetResponse, + ApiError + >(undefined, { + query: { + enabled: Boolean(isSignedIn && isAdmin), + refetchOnMount: "always", + }, + }); + + const createMutation = useCreateModelApiV1ModelRegistryModelsPost({ + mutation: { + onSuccess: async () => { + await queryClient.invalidateQueries({ queryKey: modelsKey }); + router.push(withGatewayQuery("/models/catalog", gatewayId)); + }, + onError: (err) => { + setError(err.message || "Unable to create model."); + }, + }, + }); + + const updateMutation = useUpdateModelApiV1ModelRegistryModelsModelIdPatch({ + mutation: { + onSuccess: async () => { + await queryClient.invalidateQueries({ queryKey: modelsKey }); + router.push(withGatewayQuery("/models/catalog", gatewayId)); + }, + onError: (err) => { + setError(err.message || "Unable to update model."); + }, + }, + }); + + const gateways = useMemo(() => { + if (gatewaysQuery.data?.status !== 200) return []; + return gatewaysQuery.data.data.items ?? []; + }, [gatewaysQuery.data]); + + const models = useMemo(() => { + if (modelsQuery.data?.status !== 200) return []; + return modelsQuery.data.data; + }, [modelsQuery.data]); + + const currentItem = useMemo(() => { + if (props.mode !== "edit" || !modelIdParam) return null; + return models.find((item) => item.id === modelIdParam) ?? null; + }, [modelIdParam, models, props.mode]); + + const gatewayOptions = useMemo(() => toGatewayOptions(gateways), [gateways]); + const requestedGateway = searchParams.get("gateway")?.trim() ?? ""; + const gatewayId = (() => { + if (gateways.length === 0) return ""; + if (props.mode === "edit" && currentItem?.gateway_id) { + return currentItem.gateway_id; + } + if (gatewayDraft && gateways.some((gateway) => gateway.id === gatewayDraft)) { + return gatewayDraft; + } + if (requestedGateway && gateways.some((gateway) => gateway.id === requestedGateway)) { + return requestedGateway; + } + return gateways[0].id; + })(); + + const provider = providerDraft ?? (props.mode === "edit" ? (currentItem?.provider ?? "") : ""); + const modelId = modelIdDraft ?? (props.mode === "edit" ? (currentItem?.model_id ?? "") : ""); + const displayName = + displayNameDraft ?? (props.mode === "edit" ? (currentItem?.display_name ?? "") : ""); + const settingsText = + settingsDraft ?? + (props.mode === "edit" && currentItem?.settings + ? JSON.stringify(currentItem.settings, null, 2) + : ""); + + const isBusy = createMutation.isPending || updateMutation.isPending; + const pageError = gatewaysQuery.error?.message ?? modelsQuery.error?.message ?? null; + + const title = props.mode === "create" ? "Add catalog model" : "Edit catalog model"; + const description = + props.mode === "create" + ? "Create a gateway model catalog entry for agent routing." + : "Update model metadata and settings for this catalog entry."; + + const handleSubmit = (event: FormEvent) => { + event.preventDefault(); + + if (!gatewayId) { + setError("Select a gateway first."); + return; + } + + const normalizedProvider = provider.trim().toLowerCase(); + const normalizedModelId = modelId.trim(); + const normalizedDisplayName = displayName.trim(); + + if (!normalizedProvider || !normalizedModelId || !normalizedDisplayName) { + setError("Provider, model ID, and display name are required."); + return; + } + + let settings: Record | undefined; + const normalizedSettings = settingsText.trim(); + if (normalizedSettings) { + try { + const parsed = JSON.parse(normalizedSettings) as unknown; + if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) { + throw new Error("settings must be an object"); + } + settings = parsed as Record; + } catch { + setError("Settings must be a valid JSON object."); + return; + } + } else if (props.mode === "edit") { + settings = {}; + } + + setError(null); + + if (props.mode === "create") { + createMutation.mutate({ + data: { + gateway_id: gatewayId, + provider: normalizedProvider, + model_id: normalizedModelId, + display_name: normalizedDisplayName, + settings, + }, + }); + return; + } + + if (!modelIdParam) { + setError("Missing model identifier."); + return; + } + + updateMutation.mutate({ + modelId: modelIdParam, + data: { + provider: normalizedProvider, + model_id: normalizedModelId, + display_name: normalizedDisplayName, + settings, + }, + }); + }; + + const missingEditItem = + props.mode === "edit" && + !modelsQuery.isLoading && + modelsQuery.data?.status === 200 && + !currentItem; + + return ( + +
+ {missingEditItem ? ( +
+ Catalog model entry not found. +
+ ) : ( +
+
+

Model details

+ +
+ + + + + + + + +