From 9340a74c42593374bc397c9251659efdc082d2f0 Mon Sep 17 00:00:00 2001 From: Abhimanyu Saharan Date: Mon, 9 Feb 2026 02:24:16 +0530 Subject: [PATCH] refactor: replace generic Exception handling with SQLAlchemyError in CRUD and session management --- backend/app/db/crud.py | 8 +++--- backend/app/db/session.py | 13 +++++++--- backend/tests/test_db_transaction_safety.py | 28 +++++++++++++++------ 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/backend/app/db/crud.py b/backend/app/db/crud.py index bd865b22..d5689c7a 100644 --- a/backend/app/db/crud.py +++ b/backend/app/db/crud.py @@ -5,7 +5,7 @@ from typing import Any, TypeVar, cast from sqlalchemy import delete as sql_delete from sqlalchemy import update as sql_update -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlmodel import SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql.expression import SelectOfScalar @@ -24,7 +24,7 @@ class MultipleObjectsReturned(LookupError): async def _flush_or_rollback(session: AsyncSession) -> None: try: await session.flush() - except Exception: + except SQLAlchemyError: await session.rollback() raise @@ -32,7 +32,7 @@ async def _flush_or_rollback(session: AsyncSession) -> None: async def _commit_or_rollback(session: AsyncSession) -> None: try: await session.commit() - except Exception: + except SQLAlchemyError: await session.rollback() raise @@ -268,7 +268,7 @@ async def get_or_create( if existing is not None: return existing, False raise - except Exception: + except SQLAlchemyError: await session.rollback() raise diff --git a/backend/app/db/session.py b/backend/app/db/session.py index 8a71ba1b..9d1b7754 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -7,6 +7,7 @@ from pathlib import Path import anyio from alembic import command from alembic.config import Config +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -69,9 +70,15 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]: async with async_session_maker() as session: try: yield session - except Exception: + finally: + try: + in_txn = bool(session.in_transaction()) + except SQLAlchemyError: + logger.exception("Failed to inspect session transaction state.") + return + if not in_txn: + return try: await session.rollback() - except Exception: + except SQLAlchemyError: logger.exception("Failed to rollback session after request error.") - raise diff --git a/backend/tests/test_db_transaction_safety.py b/backend/tests/test_db_transaction_safety.py index 348a6bb7..9dd293b3 100644 --- a/backend/tests/test_db_transaction_safety.py +++ b/backend/tests/test_db_transaction_safety.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from typing import Any import pytest +from sqlalchemy.exc import SQLAlchemyError from sqlmodel import Field, SQLModel from app.db import crud @@ -43,12 +44,17 @@ class _Maker: @pytest.mark.asyncio async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None: - fake_session = SimpleNamespace(rollbacks=0) + @dataclass + class _FakeDependencySession: + rollbacks: int = 0 - async def _rollback() -> None: - fake_session.rollbacks += 1 + def in_transaction(self) -> bool: + return True - fake_session.rollback = _rollback + async def rollback(self) -> None: + self.rollbacks += 1 + + fake_session = _FakeDependencySession() ctx = _SessionCtx(fake_session) monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx)) @@ -66,6 +72,9 @@ async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.Mo @pytest.mark.asyncio async def test_create_rolls_back_when_commit_fails() -> None: + class _CommitError(SQLAlchemyError): + pass + @dataclass class _FailCommitSession: rollback_calls: int = 0 @@ -82,7 +91,7 @@ async def test_create_rolls_back_when_commit_fails() -> None: return None async def commit(self) -> None: - raise RuntimeError("commit failed") + raise _CommitError("commit failed") async def rollback(self) -> None: self.rollback_calls += 1 @@ -92,7 +101,7 @@ async def test_create_rolls_back_when_commit_fails() -> None: session = _FailCommitSession() - with pytest.raises(RuntimeError, match="commit failed"): + with pytest.raises(SQLAlchemyError, match="commit failed"): await crud.create(session, _Thing, name="demo") assert session.rollback_calls == 1 @@ -101,6 +110,9 @@ async def test_create_rolls_back_when_commit_fails() -> None: @pytest.mark.asyncio async def test_delete_where_rolls_back_when_commit_fails() -> None: + class _CommitError(SQLAlchemyError): + pass + @dataclass class _FailCommitDmlSession: rollback_calls: int = 0 @@ -111,14 +123,14 @@ async def test_delete_where_rolls_back_when_commit_fails() -> None: return SimpleNamespace(rowcount=3) async def commit(self) -> None: - raise RuntimeError("commit failed") + raise _CommitError("commit failed") async def rollback(self) -> None: self.rollback_calls += 1 session = _FailCommitDmlSession() - with pytest.raises(RuntimeError, match="commit failed"): + with pytest.raises(SQLAlchemyError, match="commit failed"): await crud.delete_where(session, _Thing, commit=True) assert session.exec_calls == 1