Commit d935165

Fix psycopg3 tests
Several tests (such as SQLPanelTestCase.test_cursor_wrapper_singleton) are written to ensure that only a single cursor wrapper is instantiated during a test. However, this fails when using psycopg3, since the .last_executed_query() call in NormalCursorWrapper._record() ends up creating an additional cursor (via [1]).

To avoid this, use a ._djdt_in_record attribute on the database wrapper. Make the NormalCursorWrapper._record() method set ._djdt_in_record to True on entry and reset it to False on exit. Then, in the overridden database wrapper .cursor() and .chunked_cursor() methods, check the ._djdt_in_record attribute and return the original cursor without wrapping if the attribute is True.

[1] https://github.com/django/django/blob/879e5d587b84e6fc961829611999431778eb9f6a/django/db/backends/postgresql/psycopg_any.py#L21
1 parent 13ce4c6 commit d935165
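In outline, the change is a reentrancy guard around the recording step. The sketch below is a standalone model of that guard, not the toolbar's actual code (the real change is in debug_toolbar/panels/sql/tracking.py, shown in the diff further down); FakeConnection and wrapped_count are invented here purely for illustration.

    # Standalone sketch of the reentrancy guard described in the commit message.
    # FakeConnection and wrapped_count are illustrative stand-ins, not toolbar code.
    class FakeConnection:
        _djdt_in_record = False

        def _djdt_cursor(self):
            return object()  # the "raw" cursor returned by the database backend


    wrapped_count = 0


    def cursor(connection):
        # Overridden cursor(): while _record() is running, hand back the raw
        # cursor so bookkeeping cursors are never wrapped (or counted).
        raw = connection._djdt_cursor()
        if connection._djdt_in_record:
            return raw
        global wrapped_count
        wrapped_count += 1
        return ("wrapped", raw)


    def record(connection):
        # _record(): set the guard on entry, always reset it on exit.
        connection._djdt_in_record = True
        try:
            cursor(connection)  # e.g. the extra cursor opened by last_executed_query()
        finally:
            connection._djdt_in_record = False


    conn = FakeConnection()
    cursor(conn)  # the query's own cursor -> wrapped
    record(conn)  # the bookkeeping cursor -> left unwrapped
    assert wrapped_count == 1

With the same pattern applied in tracking.py, the extra cursor that psycopg3 needs for .last_executed_query() never passes through NormalCursorWrapper, so the singleton assertions in the test suite hold again.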

File tree

1 file changed (+92, -80)

debug_toolbar/panels/sql/tracking.py

Lines changed: 92 additions & 80 deletions
@@ -43,16 +43,21 @@ def cursor(*args, **kwargs):
         # See:
         # https://github.com/jazzband/django-debug-toolbar/pull/615
         # https://github.com/jazzband/django-debug-toolbar/pull/896
+        cursor = connection._djdt_cursor(*args, **kwargs)
+        if connection._djdt_in_record:
+            return cursor
         if allow_sql.get():
             wrapper = NormalCursorWrapper
         else:
             wrapper = ExceptionCursorWrapper
-        return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)
+        return wrapper(cursor, connection, panel)

     def chunked_cursor(*args, **kwargs):
         # prevent double wrapping
         # solves https://github.com/jazzband/django-debug-toolbar/issues/1239
         cursor = connection._djdt_chunked_cursor(*args, **kwargs)
+        if connection._djdt_in_record:
+            return cursor
         if not isinstance(cursor, BaseCursorWrapper):
             if allow_sql.get():
                 wrapper = NormalCursorWrapper
@@ -63,6 +68,7 @@ def chunked_cursor(*args, **kwargs):

     connection.cursor = cursor
     connection.chunked_cursor = chunked_cursor
+    connection._djdt_in_record = False


 def unwrap_cursor(connection):
@@ -154,90 +160,96 @@ def _decode(self, param):
             return "(encoded string)"

     def _record(self, method, sql, params):
-        alias = self.db.alias
-        vendor = self.db.vendor
-
-        if vendor == "postgresql":
-            # The underlying DB connection (as opposed to Django's wrapper)
-            conn = self.db.connection
-            initial_conn_status = conn.info.transaction_status
-
-        start_time = time()
+        self.db._djdt_in_record = True
         try:
-            return method(sql, params)
-        finally:
-            stop_time = time()
-            duration = (stop_time - start_time) * 1000
-            _params = ""
-            try:
-                _params = json.dumps(self._decode(params))
-            except TypeError:
-                pass  # object not JSON serializable
-            template_info = get_template_info()
-
-            # Sql might be an object (such as psycopg Composed).
-            # For logging purposes, make sure it's str.
-            if vendor == "postgresql" and not isinstance(sql, str):
-                sql = sql.as_string(conn)
-            else:
-                sql = str(sql)
-
-            params = {
-                "vendor": vendor,
-                "alias": alias,
-                "sql": self.db.ops.last_executed_query(
-                    self.cursor, sql, self._quote_params(params)
-                ),
-                "duration": duration,
-                "raw_sql": sql,
-                "params": _params,
-                "raw_params": params,
-                "stacktrace": get_stack_trace(skip=2),
-                "start_time": start_time,
-                "stop_time": stop_time,
-                "is_slow": duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"],
-                "is_select": sql.lower().strip().startswith("select"),
-                "template_info": template_info,
-            }
+            alias = self.db.alias
+            vendor = self.db.vendor

             if vendor == "postgresql":
-                # If an erroneous query was ran on the connection, it might
-                # be in a state where checking isolation_level raises an
-                # exception.
+                # The underlying DB connection (as opposed to Django's wrapper)
+                conn = self.db.connection
+                initial_conn_status = conn.info.transaction_status
+
+            start_time = time()
+            try:
+                return method(sql, params)
+            finally:
+                stop_time = time()
+                duration = (stop_time - start_time) * 1000
+                _params = ""
                 try:
-                    iso_level = conn.isolation_level
-                except conn.InternalError:
-                    iso_level = "unknown"
-                # PostgreSQL does not expose any sort of transaction ID, so it is
-                # necessary to generate synthetic transaction IDs here. If the
-                # connection was not in a transaction when the query started, and was
-                # after the query finished, a new transaction definitely started, so get
-                # a new transaction ID from logger.new_transaction_id(). If the query
-                # was in a transaction both before and after executing, make the
-                # assumption that it is the same transaction and get the current
-                # transaction ID from logger.current_transaction_id(). There is an edge
-                # case where Django can start a transaction before the first query
-                # executes, so in that case logger.current_transaction_id() will
-                # generate a new transaction ID since one does not already exist.
-                final_conn_status = conn.info.transaction_status
-                if final_conn_status == STATUS_IN_TRANSACTION:
-                    if initial_conn_status == STATUS_IN_TRANSACTION:
-                        trans_id = self.logger.current_transaction_id(alias)
-                    else:
-                        trans_id = self.logger.new_transaction_id(alias)
+                    _params = json.dumps(self._decode(params))
+                except TypeError:
+                    pass  # object not JSON serializable
+                template_info = get_template_info()
+
+                # Sql might be an object (such as psycopg Composed).
+                # For logging purposes, make sure it's str.
+                if vendor == "postgresql" and not isinstance(sql, str):
+                    sql = sql.as_string(conn)
                 else:
-                    trans_id = None
-
-            params.update(
-                {
-                    "trans_id": trans_id,
-                    "trans_status": conn.info.transaction_status,
-                    "iso_level": iso_level,
-                }
-            )
-
-            # We keep `sql` to maintain backwards compatibility
-            self.logger.record(**params)
+                    sql = str(sql)
+
+                params = {
+                    "vendor": vendor,
+                    "alias": alias,
+                    "sql": self.db.ops.last_executed_query(
+                        self.cursor, sql, self._quote_params(params)
+                    ),
+                    "duration": duration,
+                    "raw_sql": sql,
+                    "params": _params,
+                    "raw_params": params,
+                    "stacktrace": get_stack_trace(skip=2),
+                    "start_time": start_time,
+                    "stop_time": stop_time,
+                    "is_slow": (
+                        duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
+                    ),
+                    "is_select": sql.lower().strip().startswith("select"),
+                    "template_info": template_info,
+                }
+
+                if vendor == "postgresql":
+                    # If an erroneous query was ran on the connection, it might
+                    # be in a state where checking isolation_level raises an
+                    # exception.
+                    try:
+                        iso_level = conn.isolation_level
+                    except conn.InternalError:
+                        iso_level = "unknown"
+                    # PostgreSQL does not expose any sort of transaction ID, so it is
+                    # necessary to generate synthetic transaction IDs here. If the
+                    # connection was not in a transaction when the query started, and was
+                    # after the query finished, a new transaction definitely started, so get
+                    # a new transaction ID from logger.new_transaction_id(). If the query
+                    # was in a transaction both before and after executing, make the
+                    # assumption that it is the same transaction and get the current
+                    # transaction ID from logger.current_transaction_id(). There is an edge
+                    # case where Django can start a transaction before the first query
+                    # executes, so in that case logger.current_transaction_id() will
+                    # generate a new transaction ID since one does not already exist.
+                    final_conn_status = conn.info.transaction_status
+                    if final_conn_status == STATUS_IN_TRANSACTION:
+                        if initial_conn_status == STATUS_IN_TRANSACTION:
+                            trans_id = self.logger.current_transaction_id(alias)
+                        else:
+                            trans_id = self.logger.new_transaction_id(alias)
+                    else:
+                        trans_id = None
+
+                    params.update(
+                        {
+                            "trans_id": trans_id,
+                            "trans_status": conn.info.transaction_status,
+                            "iso_level": iso_level,
+                        }
+                    )
+
+                # We keep `sql` to maintain backwards compatibility
+                self.logger.record(**params)
+        finally:
+            self.db._djdt_in_record = False

     def callproc(self, procname, params=None):
         return self._record(self.cursor.callproc, procname, params)
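For context, the kind of check that was failing can be paraphrased as follows. This is an illustration only, not the project's actual test code; it assumes a connection that the SQL panel has already instrumented (i.e. the toolbar has replaced connection.cursor with its wrapping version), and it simply counts wrapper instantiations for a single query.

    from unittest import mock

    from debug_toolbar.panels.sql import tracking


    def assert_single_wrapper(connection):
        # Count how many times the cursor wrapper class is instantiated while
        # one query runs on an instrumented connection.
        with mock.patch.object(
            tracking, "NormalCursorWrapper", wraps=tracking.NormalCursorWrapper
        ) as mocked:
            cursor = connection.cursor()
            cursor.execute("SELECT 1")
            # On psycopg3, last_executed_query() opens a second cursor inside
            # _record(); the _djdt_in_record guard returns that cursor
            # unwrapped, so exactly one NormalCursorWrapper is created.
            assert mocked.call_count == 1

Without the guard, the bookkeeping cursor opened while formatting the query would be wrapped as well, and the count would come out as 2 under psycopg3.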
