refactor: replace generic Exception handling with SQLAlchemyError in CRUD and session management
This commit is contained in:
@@ -5,6 +5,7 @@ from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from app.db import crud
|
||||
@@ -43,12 +44,17 @@ class _Maker:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_session = SimpleNamespace(rollbacks=0)
|
||||
@dataclass
|
||||
class _FakeDependencySession:
|
||||
rollbacks: int = 0
|
||||
|
||||
async def _rollback() -> None:
|
||||
fake_session.rollbacks += 1
|
||||
def in_transaction(self) -> bool:
|
||||
return True
|
||||
|
||||
fake_session.rollback = _rollback
|
||||
async def rollback(self) -> None:
|
||||
self.rollbacks += 1
|
||||
|
||||
fake_session = _FakeDependencySession()
|
||||
ctx = _SessionCtx(fake_session)
|
||||
monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx))
|
||||
|
||||
@@ -66,6 +72,9 @@ async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.Mo
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_rolls_back_when_commit_fails() -> None:
|
||||
class _CommitError(SQLAlchemyError):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class _FailCommitSession:
|
||||
rollback_calls: int = 0
|
||||
@@ -82,7 +91,7 @@ async def test_create_rolls_back_when_commit_fails() -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
raise RuntimeError("commit failed")
|
||||
raise _CommitError("commit failed")
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollback_calls += 1
|
||||
@@ -92,7 +101,7 @@ async def test_create_rolls_back_when_commit_fails() -> None:
|
||||
|
||||
session = _FailCommitSession()
|
||||
|
||||
with pytest.raises(RuntimeError, match="commit failed"):
|
||||
with pytest.raises(SQLAlchemyError, match="commit failed"):
|
||||
await crud.create(session, _Thing, name="demo")
|
||||
|
||||
assert session.rollback_calls == 1
|
||||
@@ -101,6 +110,9 @@ async def test_create_rolls_back_when_commit_fails() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_where_rolls_back_when_commit_fails() -> None:
|
||||
class _CommitError(SQLAlchemyError):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class _FailCommitDmlSession:
|
||||
rollback_calls: int = 0
|
||||
@@ -111,14 +123,14 @@ async def test_delete_where_rolls_back_when_commit_fails() -> None:
|
||||
return SimpleNamespace(rowcount=3)
|
||||
|
||||
async def commit(self) -> None:
|
||||
raise RuntimeError("commit failed")
|
||||
raise _CommitError("commit failed")
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollback_calls += 1
|
||||
|
||||
session = _FailCommitDmlSession()
|
||||
|
||||
with pytest.raises(RuntimeError, match="commit failed"):
|
||||
with pytest.raises(SQLAlchemyError, match="commit failed"):
|
||||
await crud.delete_where(session, _Thing, commit=True)
|
||||
|
||||
assert session.exec_calls == 1
|
||||
|
||||
Reference in New Issue
Block a user