1+ import asyncio
12import datetime
23import os
34import unittest
45from unittest .mock import patch
56
67import django
8+ from asgiref .sync import sync_to_async
79from django .contrib .auth .models import User
810from django .db import connection
911from django .db .models import Count
1618
1719from ..base import BaseTestCase
1820from ..models import PostgresJSON
21+ from ..sync import database_sync_to_async
22+
23+
24+ def sql_call (use_iterator = False ):
25+ qs = User .objects .all ()
26+ if use_iterator :
27+ qs = qs .iterator ()
28+ return list (qs )
1929
2030
2131class SQLPanelTestCase (BaseTestCase ):
@@ -30,7 +40,7 @@ def test_disabled(self):
3040 def test_recording (self ):
3141 self .assertEqual (len (self .panel ._queries ), 0 )
3242
33- list ( User . objects . all () )
43+ sql_call ( )
3444
3545 # ensure query was logged
3646 self .assertEqual (len (self .panel ._queries ), 1 )
@@ -49,29 +59,64 @@ def test_recording(self):
4959 def test_recording_chunked_cursor (self ):
5060 self .assertEqual (len (self .panel ._queries ), 0 )
5161
52- list ( User . objects . all (). iterator () )
62+ sql_call ( use_iterator = True )
5363
5464 # ensure query was logged
5565 self .assertEqual (len (self .panel ._queries ), 1 )
5666
57- @patch ("debug_toolbar.panels.sql.tracking.state" , wraps = sql_tracking .state )
58- def test_cursor_wrapper_singleton (self , mock_state ):
59- list (User .objects .all ())
67+ @patch (
68+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
69+ wraps = sql_tracking .NormalCursorWrapper ,
70+ )
71+ def test_cursor_wrapper_singleton (self , mock_wrapper ):
72+ sql_call ()
6073
6174 # ensure that cursor wrapping is applied only once
62- self .assertEqual (mock_state . Wrapper .call_count , 1 )
75+ self .assertEqual (mock_wrapper .call_count , 1 )
6376
64- @patch ("debug_toolbar.panels.sql.tracking.state" , wraps = sql_tracking .state )
65- def test_chunked_cursor_wrapper_singleton (self , mock_state ):
66- list (User .objects .all ().iterator ())
77+ @patch (
78+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
79+ wraps = sql_tracking .NormalCursorWrapper ,
80+ )
81+ def test_chunked_cursor_wrapper_singleton (self , mock_wrapper ):
82+ sql_call (use_iterator = True )
6783
6884 # ensure that cursor wrapping is applied only once
69- self .assertEqual (mock_state .Wrapper .call_count , 1 )
85+ self .assertEqual (mock_wrapper .call_count , 1 )
86+
87+ @patch (
88+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
89+ wraps = sql_tracking .NormalCursorWrapper ,
90+ )
91+ async def test_cursor_wrapper_async (self , mock_wrapper ):
92+ await sync_to_async (sql_call )()
93+
94+ self .assertEqual (mock_wrapper .call_count , 1 )
95+
96+ @patch (
97+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
98+ wraps = sql_tracking .NormalCursorWrapper ,
99+ )
100+ async def test_cursor_wrapper_asyncio_ctx (self , mock_wrapper ):
101+ self .assertTrue (sql_tracking .recording .get ())
102+ await sync_to_async (sql_call )()
103+
104+ async def task ():
105+ sql_tracking .recording .set (False )
106+ # Calling this in another context requires the db connections
107+ # to be closed properly.
108+ await database_sync_to_async (sql_call )()
109+
110+ # Ensure this is called in another context
111+ await asyncio .create_task (task ())
112+ # Because it was called in another context, it should not have affected ours
113+ self .assertTrue (sql_tracking .recording .get ())
114+ self .assertEqual (mock_wrapper .call_count , 1 )
70115
71116 def test_generate_server_timing (self ):
72117 self .assertEqual (len (self .panel ._queries ), 0 )
73118
74- list ( User . objects . all () )
119+ sql_call ( )
75120
76121 response = self .panel .process_request (self .request )
77122 self .panel .generate_stats (self .request , response )
@@ -337,7 +382,7 @@ def test_disable_stacktraces(self):
337382 self .assertEqual (len (self .panel ._queries ), 0 )
338383
339384 with self .settings (DEBUG_TOOLBAR_CONFIG = {"ENABLE_STACKTRACES" : False }):
340- list ( User . objects . all () )
385+ sql_call ( )
341386
342387 # ensure query was logged
343388 self .assertEqual (len (self .panel ._queries ), 1 )
0 commit comments