1818
1919from ..base import BaseTestCase
2020from ..models import PostgresJSON
21+ from ..sync import database_sync_to_async
22+
23+
24+ def sql_call (use_iterator = False ):
25+ qs = User .objects .all ()
26+ if use_iterator :
27+ qs = qs .iterator ()
28+ return list (qs )
2129
2230
2331class SQLPanelTestCase (BaseTestCase ):
@@ -32,7 +40,7 @@ def test_disabled(self):
3240 def test_recording (self ):
3341 self .assertEqual (len (self .panel ._queries ), 0 )
3442
35- list ( User . objects . all () )
43+ sql_call ( )
3644
3745 # ensure query was logged
3846 self .assertEqual (len (self .panel ._queries ), 1 )
@@ -51,7 +59,7 @@ def test_recording(self):
5159 def test_recording_chunked_cursor (self ):
5260 self .assertEqual (len (self .panel ._queries ), 0 )
5361
54- list ( User . objects . all (). iterator () )
62+ sql_call ( use_iterator = True )
5563
5664 # ensure query was logged
5765 self .assertEqual (len (self .panel ._queries ), 1 )
@@ -61,7 +69,7 @@ def test_recording_chunked_cursor(self):
6169 wraps = sql_tracking .NormalCursorWrapper ,
6270 )
6371 def test_cursor_wrapper_singleton (self , mock_wrapper ):
64- list ( User . objects . all () )
72+ sql_call ( )
6573
6674 # ensure that cursor wrapping is applied only once
6775 self .assertEqual (mock_wrapper .call_count , 1 )
@@ -71,7 +79,7 @@ def test_cursor_wrapper_singleton(self, mock_wrapper):
7179 wraps = sql_tracking .NormalCursorWrapper ,
7280 )
7381 def test_chunked_cursor_wrapper_singleton (self , mock_wrapper ):
74- list ( User . objects . all (). iterator () )
82+ sql_call ( use_iterator = True )
7583
7684 # ensure that cursor wrapping is applied only once
7785 self .assertEqual (mock_wrapper .call_count , 1 )
@@ -81,7 +89,7 @@ def test_chunked_cursor_wrapper_singleton(self, mock_wrapper):
8189 wraps = sql_tracking .NormalCursorWrapper ,
8290 )
8391 async def test_cursor_wrapper_async (self , mock_wrapper ):
84- await sync_to_async (list )( User . objects . all () )
92+ await sync_to_async (sql_call )( )
8593
8694 self .assertEqual (mock_wrapper .call_count , 1 )
8795
@@ -91,11 +99,13 @@ async def test_cursor_wrapper_async(self, mock_wrapper):
9199 )
92100 async def test_cursor_wrapper_asyncio_ctx (self , mock_wrapper ):
93101 self .assertTrue (sql_tracking .recording .get ())
94- await sync_to_async (list )( User . objects . all () )
102+ await sync_to_async (sql_call )( )
95103
96104 async def task ():
97105 sql_tracking .recording .set (False )
98- await sync_to_async (list )(User .objects .all ())
106+ # Calling this in another context requires the db connections
107+ # to be closed properly.
108+ await database_sync_to_async (sql_call )()
99109
100110 # Ensure this is called in another context
101111 await asyncio .create_task (task ())
@@ -106,7 +116,7 @@ async def task():
106116 def test_generate_server_timing (self ):
107117 self .assertEqual (len (self .panel ._queries ), 0 )
108118
109- list ( User . objects . all () )
119+ sql_call ( )
110120
111121 response = self .panel .process_request (self .request )
112122 self .panel .generate_stats (self .request , response )
@@ -372,7 +382,7 @@ def test_disable_stacktraces(self):
372382 self .assertEqual (len (self .panel ._queries ), 0 )
373383
374384 with self .settings (DEBUG_TOOLBAR_CONFIG = {"ENABLE_STACKTRACES" : False }):
375- list ( User . objects . all () )
385+ sql_call ( )
376386
377387 # ensure query was logged
378388 self .assertEqual (len (self .panel ._queries ), 1 )
0 commit comments