diff --git a/pytest_flask_sqlalchemy/fixtures.py b/pytest_flask_sqlalchemy/fixtures.py index 0396bd6..6b3f2eb 100644 --- a/pytest_flask_sqlalchemy/fixtures.py +++ b/pytest_flask_sqlalchemy/fixtures.py @@ -1,5 +1,5 @@ -import os import contextlib +import os import pytest import sqlalchemy as sa @@ -144,6 +144,12 @@ def raw_connection(): engine.raw_connection = raw_connection + # Fix SessionTransaction._connection_for_bind caching + @sa.event.listens_for(session, 'after_begin') + def after_begin(session, transaction, conn): + if engine not in transaction._connections: + transaction._connections[engine] = transaction._connections[conn] + for mocked_engine in pytestconfig._mocked_engines: mocker.patch(mocked_engine, new=engine) diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index d1e953f..f775a0a 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -336,3 +336,32 @@ def test_delete_message(account_address, db_session): result = db_testdir.runpytest() result.assert_outcomes(passed=1) + + +def test_rollback_nested(db_testdir): + ''' + Test that creating objects and emitting SQL in the ORM won't bleed into + other tests. + ''' + # Load tests from file + db_testdir.makepyfile(""" + def test_rollback_nested(person, db_session, caplog): + assert db_session.query(person).count() == 0 + n1 = db_session.begin_nested() + db_session.add(person()) + assert db_session.query(person).count() == 1 + + n2 = db_session.begin_nested() + db_session.add(person()) + assert db_session.query(person).count() == 2 + + n2.rollback() + print(db_session.bind.mock_calls) + assert db_session.query(person).count() == 1 + n1.rollback() + assert db_session.query(person).count() == 0 + """) + + # Run tests + result = db_testdir.runpytest() + result.assert_outcomes(passed=1)