106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
# ruff: noqa: S101
|
|
"""Regression tests for agent model-assignment update normalization."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from typing import Any
|
|
from uuid import UUID, uuid4
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock
|
|
|
|
from app.services.openclaw.provisioning_db import AgentLifecycleService
|
|
|
|
|
|
class _NoAutoflush:
    """Context manager standing in for ``Session.no_autoflush`` on the session stub.

    While the ``with`` body runs, the owning stub's ``in_no_autoflush`` flag is
    ``True``. Leaving the block, normally or through an exception, resets the flag
    to ``False``. Exceptions are never suppressed.
    """

    def __init__(self, session: "_SessionStub") -> None:
        # The stub whose flag this manager flips on enter/exit.
        self._session = session

    def __enter__(self) -> None:
        self._toggle(True)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: object,
    ) -> bool:
        self._toggle(False)
        # False tells the interpreter to re-raise any in-flight exception.
        return False

    def _toggle(self, active: bool) -> None:
        """Write *active* into the owning stub's ``in_no_autoflush`` flag."""
        self._session.in_no_autoflush = active
|
|
|
|
|
|
class _SessionStub:
    """Minimal async-session double for the lifecycle-service tests.

    ``exec`` refuses to run outside ``no_autoflush``. Inside it, ``exec`` answers
    every statement with the configured ``valid_ids``. ``commit`` calls are
    counted so tests can check how often the service committed.
    """

    def __init__(self, valid_ids: set[UUID]) -> None:
        self.valid_ids = valid_ids  # ids handed back by exec()
        self.in_no_autoflush = False  # flipped by the no_autoflush context manager
        self.commits = 0  # bumped once per commit()

    @property
    def no_autoflush(self) -> _NoAutoflush:
        # A fresh manager per attribute access, used as ``with session.no_autoflush:``.
        return _NoAutoflush(self)

    async def exec(self, _statement: Any) -> list[UUID]:
        # The statement itself is ignored; only the calling context is checked.
        if self.in_no_autoflush:
            return list(self.valid_ids)
        raise AssertionError("Expected normalize query to run under no_autoflush.")

    def add(self, _model: Any) -> None:
        """Accept a model and discard it; nothing is tracked."""

    async def commit(self) -> None:
        self.commits = self.commits + 1

    async def refresh(self, _model: Any) -> None:
        """No-op refresh; the model is left untouched."""
|
|
|
|
|
|
@dataclass
class _AgentStub:
    """Lightweight stand-in for the agent model passed to ``AgentLifecycleService``.

    Only ``gateway_id`` is required; every other attribute has a default.
    """

    gateway_id: UUID
    board_id: UUID | None = None
    is_board_lead: bool = False
    openclaw_session_id: str | None = None
    primary_model_id: UUID | None = None
    # Held as strings, which is the form the update tests assert on.
    fallback_model_ids: list[str] | None = None
    updated_at: datetime | None = None
    # A factory gives each instance its own dict instead of one shared default.
    heartbeat_config: dict[str, Any] | None = field(
        default_factory=lambda: dict(interval_seconds=5),
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_normalize_agent_model_assignments_uses_no_autoflush() -> None:
    """Normalization queries under ``no_autoflush`` and strips the primary from fallbacks.

    ``_SessionStub.exec`` raises unless it is called inside ``no_autoflush``, so
    reaching the assertions proves the query ran there. The primary id also
    appears in the fallback list; it must be dropped, and the surviving fallback
    must come back as a string.
    """
    primary_id, fallback_id = uuid4(), uuid4()
    stub_session = _SessionStub({primary_id, fallback_id})
    lifecycle = AgentLifecycleService(stub_session)  # type: ignore[arg-type]

    result_primary, result_fallbacks = await lifecycle.normalize_agent_model_assignments(
        gateway_id=uuid4(),
        primary_model_id=primary_id,
        fallback_model_ids=[primary_id, fallback_id],
    )

    assert result_primary == primary_id
    assert result_fallbacks == [str(fallback_id)]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_apply_agent_update_mutations_coerces_fallback_ids_to_strings(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Applying updates stores fallback model ids as strings and commits once.

    The ``updates`` dict ends up holding de-duplicated string ids. The agent
    itself gets the primary id, with that primary excluded from its fallbacks.
    """
    primary_id = uuid4()
    fallback_id = uuid4()
    stub_session = _SessionStub({primary_id, fallback_id})
    lifecycle = AgentLifecycleService(stub_session)  # type: ignore[arg-type]
    # Stub out the main-agent gateway lookup so it resolves to None.
    monkeypatch.setattr(lifecycle, "get_main_agent_gateway", AsyncMock(return_value=None))

    agent = _AgentStub(gateway_id=uuid4())
    # The duplicate fallback and the primary-in-fallbacks are deliberate.
    updates: dict[str, Any] = {
        "primary_model_id": primary_id,
        "fallback_model_ids": [primary_id, fallback_id, fallback_id],
    }

    await lifecycle.apply_agent_update_mutations(agent=agent, updates=updates, make_main=None)  # type: ignore[arg-type]

    expected_updates = [str(primary_id), str(fallback_id)]
    assert updates["fallback_model_ids"] == expected_updates
    assert agent.primary_model_id == primary_id
    assert agent.fallback_model_ids == [str(fallback_id)]
    assert stub_session.commits == 1
|