Problem description
There's a race condition between wrapper lookup and wrapper deallocation where a Python wrapper may be returned that's in the process of being deallocated. I have a reproducer for the free threading build. I think the problem can affect the default (GIL-enabled) build as well, but I don't have a reproducer yet.
This is the counterpart of the pybind11 bug (see the pybind11 fix under "See also" below).
Explanation
nb_type_put will look up an existing Python wrapper object in inst_c2p. The inst_dealloc function removes the wrapper from inst_c2p when the wrapper is deallocated.
During the nb_type_put call, it's possible that the found wrapper has a reference count of 0 and is in the process of being deallocated, but has not yet been removed from inst_c2p. In the free threading build, this can happen because nb_type_put can run concurrently with inst_dealloc up to the acquisition of the shard lock. I think this can also happen in the default (GIL-enabled) build, because things like Py_CLEAR(*dict) can call arbitrary code that may temporarily release the GIL.
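To make the window concrete, here is a heavily simplified analog of the two paths. This is my own sketch, not nanobind's actual code: the map, mutex, and function names are stand-ins for inst_c2p, its shard lock, nb_type_put, and inst_dealloc.
```cpp
#include <Python.h>

#include <mutex>
#include <unordered_map>

// Stand-ins for nanobind's inst_c2p shard and its lock (hypothetical names).
static std::unordered_map<void *, PyObject *> c2p;
static std::mutex c2p_lock;

// Analog of the lookup done by nb_type_put.
PyObject *find_existing_wrapper(void *cpp_ptr) {
    std::lock_guard<std::mutex> guard(c2p_lock);
    auto it = c2p.find(cpp_ptr);
    if (it == c2p.end())
        return nullptr;
    // RACE: another thread may already be inside the dealloc path for this
    // object (reference count == 0) but not yet have reached the erase()
    // below, so this Py_INCREF resurrects an object that is being torn down.
    Py_INCREF(it->second);
    return it->second;
}

// Analog of inst_dealloc, entered once the wrapper's refcount hits zero.
void wrapper_dealloc(PyObject * /*self*/, void *cpp_ptr) {
    // Clearing __dict__ etc. can run arbitrary code first; in the GIL build
    // that code may even release the GIL temporarily, widening the window.
    std::lock_guard<std::mutex> guard(c2p_lock);
    c2p.erase(cpp_ptr);  // only now does the entry leave the map
}
```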
Suggested fix
nb_type_put should only incref and return a wrapper if the reference count is not zero. In the GIL-enabled build, this is roughly:
```cpp
if (Py_REFCNT(seq.inst) > 0) {
    Py_INCREF(seq.inst);
    return seq.inst;
}
```
In the free threading build, we'll want to use PyUnstable_TryIncref when it's available, or implement that logic like we're doing in pybind11.
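For completeness, a minimal sketch of what such a check could look like with both builds handled in one place. The helper name is made up, and the free-threaded branch assumes PyUnstable_TryIncref() exists with the semantics described in python/cpython#128844 (increment the count only if the object is still alive, returning whether that succeeded); the pybind11-style fallback for interpreters without that API is not shown.
```cpp
#include <Python.h>

// Hypothetical helper: return a new strong reference to a wrapper found in
// inst_c2p, or NULL if that wrapper is already being deallocated.
PyObject *try_acquire_existing(PyObject *inst) {
#if defined(Py_GIL_DISABLED)
    // Free-threaded build: checking Py_REFCNT() and then calling Py_INCREF()
    // is itself racy, so the check-and-increment has to be atomic; this is
    // what the proposed PyUnstable_TryIncref() provides.
    if (PyUnstable_TryIncref(inst))
        return inst;
    return NULL;
#else
    // GIL-enabled build: deallocation may already have started (refcount == 0)
    // even though the entry is still present in inst_c2p, so check first.
    if (Py_REFCNT(inst) > 0) {
        Py_INCREF(inst);
        return inst;
    }
    return NULL;
#endif
}
```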
See also
- Make _Py_TryIncref public as an unstable API as PyUnstable_TryIncref() (python/cpython#128844)
- fix(free-threading): fix data race when using shared variables (pybind/pybind11#5494)
Reproducible example code
```cpp
#include <nanobind/nanobind.h>

namespace nb = nanobind;

struct Data {};

Data MyData;

NB_MODULE(my_ext, m) {
    nb::class_<Data>(m, "Data")
        .def_prop_ro_static(
            "MyData", [](nb::handle /*unused*/) { return &MyData; },
            nb::rv_policy::reference);
}
```
```python
import threading

import my_ext

if __name__ == "__main__":
    num_workers = 2
    barrier = threading.Barrier(num_workers)

    def closure():
        barrier.wait()
        for _ in range(1000):
            my_ext.Data.MyData

    threads = [threading.Thread(target=closure) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```