|
4 | 4 | from time import time |
5 | 5 |
|
6 | 6 | import django.test.testcases |
| 7 | +from django.db.backends.utils import CursorWrapper |
7 | 8 | from django.utils.encoding import force_str |
8 | 9 |
|
9 | 10 | from debug_toolbar import settings as dt_settings |
@@ -57,54 +58,47 @@ def cursor(*args, **kwargs): |
57 | 58 | wrapper = NormalCursorWrapper |
58 | 59 | else: |
59 | 60 | wrapper = ExceptionCursorWrapper |
60 | | - return wrapper(cursor, connection, logger) |
| 61 | + return wrapper(cursor.cursor, connection, logger) |
61 | 62 |
|
62 | 63 | def chunked_cursor(*args, **kwargs): |
63 | 64 | # prevent double wrapping |
64 | 65 | # solves https://github.com/jazzband/django-debug-toolbar/issues/1239 |
65 | 66 | logger = connection._djdt_logger |
66 | 67 | cursor = connection._djdt_chunked_cursor(*args, **kwargs) |
67 | | - if logger is not None and not isinstance(cursor, BaseCursorWrapper): |
| 68 | + if logger is not None and not isinstance(cursor, DjDTCursorWrapper): |
68 | 69 | if allow_sql.get(): |
69 | 70 | wrapper = NormalCursorWrapper |
70 | 71 | else: |
71 | 72 | wrapper = ExceptionCursorWrapper |
72 | | - return wrapper(cursor, connection, logger) |
| 73 | + return wrapper(cursor.cursor, connection, logger) |
73 | 74 | return cursor |
74 | 75 |
|
75 | 76 | connection.cursor = cursor |
76 | 77 | connection.chunked_cursor = chunked_cursor |
77 | 78 |
|
78 | 79 |
|
79 | | -class BaseCursorWrapper: |
80 | | - pass |
| 80 | +class DjDTCursorWrapper(CursorWrapper): |
| 81 | + def __init__(self, cursor, db, logger): |
| 82 | + super().__init__(cursor, db) |
| 83 | + # logger must implement a ``record`` method |
| 84 | + self.logger = logger |
81 | 85 |
|
82 | 86 |
|
83 | | -class ExceptionCursorWrapper(BaseCursorWrapper): |
| 87 | +class ExceptionCursorWrapper(DjDTCursorWrapper): |
84 | 88 | """ |
85 | 89 | Wraps a cursor and raises an exception on any operation. |
86 | 90 | Used in Templates panel. |
87 | 91 | """ |
88 | 92 |
|
89 | | - def __init__(self, cursor, db, logger): |
90 | | - pass |
91 | | - |
92 | 93 | def __getattr__(self, attr): |
93 | 94 | raise SQLQueryTriggered() |
94 | 95 |
|
95 | 96 |
|
96 | | -class NormalCursorWrapper(BaseCursorWrapper): |
| 97 | +class NormalCursorWrapper(DjDTCursorWrapper): |
97 | 98 | """ |
98 | 99 | Wraps a cursor and logs queries. |
99 | 100 | """ |
100 | 101 |
|
101 | | - def __init__(self, cursor, db, logger): |
102 | | - self.cursor = cursor |
103 | | - # Instance of a BaseDatabaseWrapper subclass |
104 | | - self.db = db |
105 | | - # logger must implement a ``record`` method |
106 | | - self.logger = logger |
107 | | - |
108 | 102 | def _quote_expr(self, element): |
109 | 103 | if isinstance(element, str): |
110 | 104 | return "'%s'" % element.replace("'", "''") |
@@ -246,22 +240,10 @@ def _record(self, method, sql, params): |
246 | 240 | self.logger.record(**params) |
247 | 241 |
|
248 | 242 | def callproc(self, procname, params=None): |
249 | | - return self._record(self.cursor.callproc, procname, params) |
| 243 | + return self._record(super().callproc, procname, params) |
250 | 244 |
|
251 | 245 | def execute(self, sql, params=None): |
252 | | - return self._record(self.cursor.execute, sql, params) |
| 246 | + return self._record(super().execute, sql, params) |
253 | 247 |
|
254 | 248 | def executemany(self, sql, param_list): |
255 | | - return self._record(self.cursor.executemany, sql, param_list) |
256 | | - |
257 | | - def __getattr__(self, attr): |
258 | | - return getattr(self.cursor, attr) |
259 | | - |
260 | | - def __iter__(self): |
261 | | - return iter(self.cursor) |
262 | | - |
263 | | - def __enter__(self): |
264 | | - return self |
265 | | - |
266 | | - def __exit__(self, type, value, traceback): |
267 | | - self.close() |
| 249 | + return self._record(super().executemany, sql, param_list) |
0 commit comments