[BUG]: Race condition between wrapper deallocation and lookup #864

@colesbury

Problem description

There's a race condition between wrapper lookup and wrapper deallocation: a lookup may return a Python wrapper that is in the process of being deallocated. I have a reproducer for the free threading build. I think the problem can also affect the default (GIL-enabled) build, but I don't have a reproducer for it yet.

This is the counterpart of the pybind11 bug:

Explanation

nb_type_put looks up an existing Python wrapper object in inst_c2p. The inst_dealloc function removes the wrapper from inst_c2p when the wrapper is deallocated.

During the nb_type_put call, it's possible that the found wrapper has a reference count of 0 and is in the process of being deallocated, but has not yet been removed from inst_c2p. In the free threading build, this can happen because nb_type_put can run concurrently with inst_dealloc up to the acquisition of the shard lock. I think this can also happen in the default (GIL-enabled) build, because things like Py_CLEAR(*dict) can call arbitrary code that may temporarily release the GIL.
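To make the window concrete, here is a minimal single-threaded sketch of the interleaving described above. It does not use nanobind or the CPython API; `FakeWrapper`, `Registry`, and `lookup_naive` are hypothetical stand-ins for the wrapper object, the inst_c2p map, and the lookup step of nb_type_put.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>

// Hypothetical stand-in for a Python wrapper object (not nanobind's type).
struct FakeWrapper {
    std::ptrdiff_t refcnt = 1;
};

// Hypothetical stand-in for the C++ pointer -> Python wrapper map.
struct Registry {
    std::unordered_map<const void *, FakeWrapper *> c2p;

    // Naive lookup: increments and returns whatever the map holds,
    // without checking whether the wrapper is still alive.
    FakeWrapper *lookup_naive(const void *cpp_ptr) {
        auto it = c2p.find(cpp_ptr);
        if (it == c2p.end())
            return nullptr;
        it->second->refcnt++;
        return it->second;
    }
};

// Replays the problematic ordering in one thread: the last reference is
// dropped, deallocation begins, and a lookup runs before the map entry
// is erased. Returns true if the dying wrapper was handed out again.
bool naive_lookup_returns_dying_wrapper() {
    int cpp_obj = 0;
    FakeWrapper w;
    Registry reg;
    reg.c2p[&cpp_obj] = &w;

    w.refcnt--;              // last reference dropped
    assert(w.refcnt == 0);   // deallocation has started, entry still present

    FakeWrapper *found = reg.lookup_naive(&cpp_obj);
    bool resurrected = (found == &w && w.refcnt == 1);

    reg.c2p.erase(&cpp_obj); // deallocation finally removes the entry
    return resurrected;
}
```

In the real extension the two halves run on different threads (or are separated by a temporary GIL release), so the caller ends up holding a pointer to memory that is about to be freed.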

Suggested fix

nb_type_put should only incref and return a wrapper if the reference count is not zero. In the GIL-enabled build, this is roughly:

if (Py_REFCNT(seq.inst) > 0) {
    Py_INCREF(seq.inst);
    return seq.inst;
}

In the free threading build, we'll want to use PyUnstable_TryIncRef when it's available, or implement that logic ourselves as we're doing in pybind11.
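The "only incref if the count is not zero" idea can be sketched with a plain atomic counter. This is an illustration under simplifying assumptions, not CPython's or pybind11's implementation: the free threading build splits the reference count into local and shared parts, which this sketch ignores. `Counted`, `try_incref`, and `decref` are hypothetical names.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>

// Hypothetical object with a single atomic reference count.
struct Counted {
    std::atomic<std::ptrdiff_t> refcnt{1};
};

// Take a new reference only if the object is still alive (refcnt > 0).
// Returns false once the count has reached zero, i.e. once deallocation
// may already be in progress.
bool try_incref(Counted &obj) {
    std::ptrdiff_t cur = obj.refcnt.load(std::memory_order_relaxed);
    while (cur > 0) {
        // On failure, compare_exchange_weak reloads `cur` with the
        // current value, so the loop re-checks cur > 0 before retrying.
        if (obj.refcnt.compare_exchange_weak(cur, cur + 1,
                                             std::memory_order_acq_rel,
                                             std::memory_order_relaxed))
            return true;
    }
    return false;
}

// Drop a reference; returns true if this was the last one.
bool decref(Counted &obj) {
    std::ptrdiff_t prev = obj.refcnt.fetch_sub(1, std::memory_order_acq_rel);
    assert(prev > 0);
    return prev == 1;
}
```

A lookup that calls `try_incref` and gets `false` would treat the entry as absent and fall through to creating a fresh wrapper, rather than returning the dying one.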

See also

Reproducible example code

#include <nanobind/nanobind.h>

namespace nb = nanobind;

struct Data {};
Data MyData;

NB_MODULE(my_ext, m) {
    nb::class_<Data>(m, "Data")
        .def_prop_ro_static("MyData", [](nb::handle /*unused*/) { return &MyData; }, nb::rv_policy::reference);
}

import threading
import my_ext


if __name__ == "__main__":
    num_workers = 2

    barrier = threading.Barrier(num_workers)

    def closure():
        barrier.wait()

        for _ in range(1000):
            my_ext.Data.MyData
    
    threads = [threading.Thread(target=closure) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
