Files
openclaw-mission-control/backend/tests/test_db_transaction_safety.py

145 lines
3.5 KiB
Python

# ruff: noqa
from __future__ import annotations

from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any

import pytest
from sqlalchemy.exc import SQLAlchemyError
from sqlmodel import Field, SQLModel

from app.db import crud
from app.db import session as db_session
class _Thing(SQLModel, table=True):
    """Minimal table-backed model passed as the target model to ``crud`` calls.

    The tests hand it to ``crud.create`` and ``crud.delete_where`` together with
    fake session objects, so no real database is touched.
    """

    # NOTE(review): ``table=True`` presumably registers ``_test_thing`` in the
    # shared ``SQLModel.metadata`` at import time; confirm that no other test
    # runs ``metadata.create_all`` against a real engine and picks it up.
    __tablename__ = "_test_thing"

    id: int | None = Field(default=None, primary_key=True)
    name: str
@dataclass
class _SessionCtx:
    """Async context manager that stands in for ``async_session_maker()``.

    On entry it hands back the wrapped ``session``. It counts how many times it
    has been entered and exited, and it never suppresses an exception raised
    inside the ``async with`` body.
    """

    session: Any
    entered: int = 0
    exited: int = 0

    async def __aenter__(self) -> Any:
        # Record the entry before giving the wrapped session to the caller.
        self.entered = self.entered + 1
        return self.session

    async def __aexit__(self, _exc_type: Any, _exc: Any, _tb: Any) -> bool:
        self.exited = self.exited + 1
        # A falsy result tells ``async with`` to let any exception propagate.
        suppress_exception = False
        return suppress_exception
@dataclass
class _Maker:
    """Callable stand-in for ``async_session_maker``.

    Every call returns the same pre-built context manager instance. This lets a
    test inspect that instance's enter/exit counters after the code under test
    has used it.
    """

    ctx: _SessionCtx

    def __call__(self) -> _SessionCtx:
        produced = self.ctx
        return produced
@pytest.mark.asyncio
async def test_get_session_rolls_back_on_dependency_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An error thrown into ``get_session`` must trigger one rollback and a clean exit.

    The session maker is swapped for a counting fake. After the async generator
    yields its session, a ``RuntimeError`` is thrown in at the yield point. The
    test then checks three things: the error is re-raised, ``rollback`` ran
    exactly once, and the session context was entered and exited exactly once.
    """

    @dataclass
    class _OpenTxnSession:
        # Fake session that always reports an open transaction and counts
        # calls to ``rollback``.
        rollbacks: int = 0

        @staticmethod
        def in_transaction() -> bool:
            return True

        async def rollback(self) -> None:
            self.rollbacks += 1

    fake = _OpenTxnSession()
    session_ctx = _SessionCtx(fake)
    # Route get_session() through the counting context manager.
    monkeypatch.setattr(db_session, "async_session_maker", _Maker(session_ctx))

    dependency = db_session.get_session()
    assert await dependency.__anext__() is fake

    # Throw the failure into the generator at its yield point; it must re-raise.
    with pytest.raises(RuntimeError, match="boom"):
        await dependency.athrow(RuntimeError("boom"))

    assert fake.rollbacks == 1
    assert (session_ctx.entered, session_ctx.exited) == (1, 1)
@pytest.mark.asyncio
async def test_create_rolls_back_when_commit_fails() -> None:
    """``crud.create`` must roll back when ``commit`` raises.

    The fake session accepts ``add``, ``flush`` and ``refresh``, but ``commit``
    raises a ``SQLAlchemyError`` subclass. The test checks three things:
    - the error propagates to the caller;
    - ``rollback`` was called exactly once;
    - exactly one object was passed to ``add``.
    """

    class _CommitError(SQLAlchemyError):
        pass

    @dataclass
    class _FailCommitSession:
        rollback_calls: int = 0
        # default_factory gives every instance its own list and keeps the
        # value consistent with the ``list[Any]`` annotation.
        added: list[Any] = field(default_factory=list)

        def add(self, value: Any) -> None:
            self.added.append(value)

        @staticmethod
        async def flush() -> None:
            return None

        @staticmethod
        async def commit() -> None:
            raise _CommitError("commit failed")

        async def rollback(self) -> None:
            self.rollback_calls += 1

        @staticmethod
        async def refresh(_value: Any) -> None:
            return None

    session = _FailCommitSession()
    with pytest.raises(SQLAlchemyError, match="commit failed"):
        await crud.create(session, _Thing, name="demo")

    assert session.rollback_calls == 1
    assert len(session.added) == 1
@pytest.mark.asyncio
async def test_delete_where_rolls_back_when_commit_fails() -> None:
    """``crud.delete_where(..., commit=True)`` must roll back when ``commit`` raises.

    The fake session's ``exec`` returns a result whose ``rowcount`` is 3, but
    ``commit`` raises a ``SQLAlchemyError`` subclass. The test checks three
    things: the error propagates, the statement was executed once, and
    ``rollback`` ran once.
    """

    class _SimulatedCommitError(SQLAlchemyError):
        pass

    @dataclass
    class _DmlSessionFailingCommit:
        exec_calls: int = 0
        rollback_calls: int = 0

        async def exec(self, _statement: Any) -> Any:
            self.exec_calls += 1
            return SimpleNamespace(rowcount=3)

        @staticmethod
        async def commit() -> None:
            raise _SimulatedCommitError("commit failed")

        async def rollback(self) -> None:
            self.rollback_calls += 1

    fake = _DmlSessionFailingCommit()
    with pytest.raises(SQLAlchemyError, match="commit failed"):
        await crud.delete_where(fake, _Thing, commit=True)

    assert (fake.exec_calls, fake.rollback_calls) == (1, 1)