@@ -41,6 +41,7 @@ PyUnstable_AtExit(PyInterpreterState *interp,
4141 callback -> next = NULL ;
4242
4343 struct atexit_state * state = & interp -> atexit ;
44+ _PyAtExit_LockCallbacks (state );
4445 atexit_callback * top = state -> ll_callbacks ;
4546 if (top == NULL ) {
4647 state -> ll_callbacks = callback ;
@@ -49,36 +50,16 @@ PyUnstable_AtExit(PyInterpreterState *interp,
4950 callback -> next = top ;
5051 state -> ll_callbacks = callback ;
5152 }
53+ _PyAtExit_UnlockCallbacks (state );
5254 return 0 ;
5355}
5456
5557
56- static void
57- atexit_delete_cb (struct atexit_state * state , int i )
58- {
59- atexit_py_callback * cb = state -> callbacks [i ];
60- state -> callbacks [i ] = NULL ;
61-
62- Py_DECREF (cb -> func );
63- Py_DECREF (cb -> args );
64- Py_XDECREF (cb -> kwargs );
65- PyMem_Free (cb );
66- }
67-
68-
6958/* Clear all callbacks without calling them */
7059static void
7160atexit_cleanup (struct atexit_state * state )
7261{
73- atexit_py_callback * cb ;
74- for (int i = 0 ; i < state -> ncallbacks ; i ++ ) {
75- cb = state -> callbacks [i ];
76- if (cb == NULL )
77- continue ;
78-
79- atexit_delete_cb (state , i );
80- }
81- state -> ncallbacks = 0 ;
62+ PyList_Clear (state -> callbacks );
8263}
8364
8465
@@ -89,23 +70,21 @@ _PyAtExit_Init(PyInterpreterState *interp)
8970 // _PyAtExit_Init() must only be called once
9071 assert (state -> callbacks == NULL );
9172
92- state -> callback_len = 32 ;
93- state -> ncallbacks = 0 ;
94- state -> callbacks = PyMem_New (atexit_py_callback * , state -> callback_len );
73+ state -> callbacks = PyList_New (0 );
9574 if (state -> callbacks == NULL ) {
9675 return _PyStatus_NO_MEMORY ();
9776 }
9877 return _PyStatus_OK ();
9978}
10079
101-
10280void
10381_PyAtExit_Fini (PyInterpreterState * interp )
10482{
83+ // In theory, there shouldn't be any threads left by now, so we
84+ // won't lock this.
10585 struct atexit_state * state = & interp -> atexit ;
10686 atexit_cleanup (state );
107- PyMem_Free (state -> callbacks );
108- state -> callbacks = NULL ;
87+ Py_CLEAR (state -> callbacks );
10988
11089 atexit_callback * next = state -> ll_callbacks ;
11190 state -> ll_callbacks = NULL ;
@@ -120,35 +99,44 @@ _PyAtExit_Fini(PyInterpreterState *interp)
12099 }
121100}
122101
123-
124102static void
125103atexit_callfuncs (struct atexit_state * state )
126104{
127105 assert (!PyErr_Occurred ());
106+ assert (state -> callbacks != NULL );
107+ assert (PyList_CheckExact (state -> callbacks ));
128108
129- if (state -> ncallbacks == 0 ) {
109+ // Create a copy of the list for thread safety
110+ PyObject * copy = PyList_GetSlice (state -> callbacks , 0 , PyList_GET_SIZE (state -> callbacks ));
111+ if (copy == NULL )
112+ {
113+ PyErr_WriteUnraisable (NULL );
130114 return ;
131115 }
132116
133- for (int i = state -> ncallbacks - 1 ; i >= 0 ; i -- ) {
134- atexit_py_callback * cb = state -> callbacks [i ];
135- if (cb == NULL ) {
136- continue ;
137- }
117+ for (Py_ssize_t i = 0 ; i < PyList_GET_SIZE (copy ); ++ i ) {
118+ // We don't have to worry about evil borrowed references, because
119+ // no other threads can access this list.
120+ PyObject * tuple = PyList_GET_ITEM (copy , i );
121+ assert (PyTuple_CheckExact (tuple ));
122+
123+ PyObject * func = PyTuple_GET_ITEM (tuple , 0 );
124+ PyObject * args = PyTuple_GET_ITEM (tuple , 1 );
125+ PyObject * kwargs = PyTuple_GET_ITEM (tuple , 2 );
138126
139- // bpo-46025: Increment the refcount of cb-> func as the call itself may unregister it
140- PyObject * the_func = Py_NewRef ( cb -> func );
141- PyObject * res = PyObject_Call ( cb -> func , cb -> args , cb -> kwargs );
127+ PyObject * res = PyObject_Call ( func ,
128+ args ,
129+ kwargs == Py_None ? NULL : kwargs );
142130 if (res == NULL ) {
143131 PyErr_FormatUnraisable (
144- "Exception ignored in atexit callback %R" , the_func );
132+ "Exception ignored in atexit callback %R" , func );
145133 }
146134 else {
147135 Py_DECREF (res );
148136 }
149- Py_DECREF (the_func );
150137 }
151138
139+ Py_DECREF (copy );
152140 atexit_cleanup (state );
153141
154142 assert (!PyErr_Occurred ());
@@ -194,33 +182,27 @@ atexit_register(PyObject *module, PyObject *args, PyObject *kwargs)
194182 "the first argument must be callable" );
195183 return NULL ;
196184 }
185+ PyObject * func_args = PyTuple_GetSlice (args , 1 , PyTuple_GET_SIZE (args ));
186+ PyObject * func_kwargs = kwargs ;
197187
198- struct atexit_state * state = get_atexit_state ();
199- if (state -> ncallbacks >= state -> callback_len ) {
200- atexit_py_callback * * r ;
201- state -> callback_len += 16 ;
202- size_t size = sizeof (atexit_py_callback * ) * (size_t )state -> callback_len ;
203- r = (atexit_py_callback * * )PyMem_Realloc (state -> callbacks , size );
204- if (r == NULL ) {
205- return PyErr_NoMemory ();
206- }
207- state -> callbacks = r ;
188+ if (func_kwargs == NULL )
189+ {
190+ func_kwargs = Py_None ;
208191 }
209-
210- atexit_py_callback * callback = PyMem_Malloc ( sizeof ( atexit_py_callback ));
211- if ( callback == NULL ) {
212- return PyErr_NoMemory () ;
192+ PyObject * callback = PyTuple_Pack ( 3 , func , func_args , func_kwargs );
193+ if ( callback == NULL )
194+ {
195+ return NULL ;
213196 }
214197
215- callback -> args = PyTuple_GetSlice (args , 1 , PyTuple_GET_SIZE (args ));
216- if (callback -> args == NULL ) {
217- PyMem_Free (callback );
198+ struct atexit_state * state = get_atexit_state ();
199+ // atexit callbacks go in a LIFO order
200+ if (PyList_Insert (state -> callbacks , 0 , callback ) < 0 )
201+ {
202+ Py_DECREF (callback );
218203 return NULL ;
219204 }
220- callback -> func = Py_NewRef (func );
221- callback -> kwargs = Py_XNewRef (kwargs );
222-
223- state -> callbacks [state -> ncallbacks ++ ] = callback ;
205+ Py_DECREF (callback );
224206
225207 return Py_NewRef (func );
226208}
@@ -264,7 +246,33 @@ static PyObject *
264246atexit_ncallbacks (PyObject * module , PyObject * unused )
265247{
266248 struct atexit_state * state = get_atexit_state ();
267- return PyLong_FromSsize_t (state -> ncallbacks );
249+ assert (state -> callbacks != NULL );
250+ assert (PyList_CheckExact (state -> callbacks ));
251+ return PyLong_FromSsize_t (PyList_GET_SIZE (state -> callbacks ));
252+ }
253+
254+ static int
255+ atexit_unregister_locked (PyObject * callbacks , PyObject * func )
256+ {
257+ for (Py_ssize_t i = 0 ; i < PyList_GET_SIZE (callbacks ); ++ i ) {
258+ PyObject * tuple = PyList_GET_ITEM (callbacks , i );
259+ assert (PyTuple_CheckExact (tuple ));
260+ PyObject * to_compare = PyTuple_GET_ITEM (tuple , 0 );
261+ int cmp = PyObject_RichCompareBool (func , to_compare , Py_EQ );
262+ if (cmp < 0 )
263+ {
264+ return -1 ;
265+ }
266+ if (cmp == 1 ) {
267+ // We found a callback!
268+ if (PyList_SetSlice (callbacks , i , i + 1 , NULL ) < 0 ) {
269+ return -1 ;
270+ }
271+ -- i ;
272+ }
273+ }
274+
275+ return 0 ;
268276}
269277
270278PyDoc_STRVAR (atexit_unregister__doc__ ,
@@ -280,22 +288,11 @@ static PyObject *
280288atexit_unregister (PyObject * module , PyObject * func )
281289{
282290 struct atexit_state * state = get_atexit_state ();
283- for (int i = 0 ; i < state -> ncallbacks ; i ++ )
284- {
285- atexit_py_callback * cb = state -> callbacks [i ];
286- if (cb == NULL ) {
287- continue ;
288- }
289-
290- int eq = PyObject_RichCompareBool (cb -> func , func , Py_EQ );
291- if (eq < 0 ) {
292- return NULL ;
293- }
294- if (eq ) {
295- atexit_delete_cb (state , i );
296- }
297- }
298- Py_RETURN_NONE ;
291+ int result ;
292+ Py_BEGIN_CRITICAL_SECTION (state -> callbacks );
293+ result = atexit_unregister_locked (state -> callbacks , func );
294+ Py_END_CRITICAL_SECTION ();
295+ return result < 0 ? NULL : Py_None ;
299296}
300297
301298
0 commit comments