@@ -702,84 +702,6 @@ size_t PyMlirContext::getLiveCount() {
702702 return getLiveContexts ().size ();
703703}
704704
705- size_t PyMlirContext::getLiveOperationCount () {
706- nb::ft_lock_guard lock (liveOperationsMutex);
707- return liveOperations.size ();
708- }
709-
710- std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects () {
711- std::vector<PyOperation *> liveObjects;
712- nb::ft_lock_guard lock (liveOperationsMutex);
713- for (auto &entry : liveOperations)
714- liveObjects.push_back (entry.second .second );
715- return liveObjects;
716- }
717-
718- size_t PyMlirContext::clearLiveOperations () {
719-
720- LiveOperationMap operations;
721- {
722- nb::ft_lock_guard lock (liveOperationsMutex);
723- std::swap (operations, liveOperations);
724- }
725- for (auto &op : operations)
726- op.second .second ->setInvalid ();
727- size_t numInvalidated = operations.size ();
728- return numInvalidated;
729- }
730-
731- void PyMlirContext::clearOperation (MlirOperation op) {
732- PyOperation *py_op;
733- {
734- nb::ft_lock_guard lock (liveOperationsMutex);
735- auto it = liveOperations.find (op.ptr );
736- if (it == liveOperations.end ()) {
737- return ;
738- }
739- py_op = it->second .second ;
740- liveOperations.erase (it);
741- }
742- py_op->setInvalid ();
743- }
744-
745- void PyMlirContext::clearOperationsInside (PyOperationBase &op) {
746- typedef struct {
747- PyOperation &rootOp;
748- bool rootSeen;
749- } callBackData;
750- callBackData data{op.getOperation (), false };
751- // Mark all ops below the op that the passmanager will be rooted
752- // at (but not op itself - note the preorder) as invalid.
753- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754- void *userData) {
755- callBackData *data = static_cast <callBackData *>(userData);
756- if (LLVM_LIKELY (data->rootSeen ))
757- data->rootOp .getOperation ().getContext ()->clearOperation (op);
758- else
759- data->rootSeen = true ;
760- return MlirWalkResult::MlirWalkResultAdvance;
761- };
762- mlirOperationWalk (op.getOperation (), invalidatingCallback,
763- static_cast <void *>(&data), MlirWalkPreOrder);
764- }
765- void PyMlirContext::clearOperationsInside (MlirOperation op) {
766- PyOperationRef opRef = PyOperation::forOperation (getRef (), op);
767- clearOperationsInside (opRef->getOperation ());
768- }
769-
770- void PyMlirContext::clearOperationAndInside (PyOperationBase &op) {
771- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
772- void *userData) {
773- PyMlirContextRef &contextRef = *static_cast <PyMlirContextRef *>(userData);
774- contextRef->clearOperation (op);
775- return MlirWalkResult::MlirWalkResultAdvance;
776- };
777- mlirOperationWalk (op.getOperation (), invalidatingCallback,
778- &op.getOperation ().getContext (), MlirWalkPreOrder);
779- }
780-
781- size_t PyMlirContext::getLiveModuleCount () { return liveModules.size (); }
782-
783705nb::object PyMlirContext::contextEnter (nb::object context) {
784706 return PyThreadContextEntry::pushContext (context);
785707}
@@ -1151,38 +1073,20 @@ PyLocation &DefaultingPyLocation::resolve() {
11511073PyModule::PyModule (PyMlirContextRef contextRef, MlirModule module )
11521074 : BaseContextObject(std::move(contextRef)), module (module ) {}
11531075
1154- PyModule::~PyModule () {
1155- nb::gil_scoped_acquire acquire;
1156- auto &liveModules = getContext ()->liveModules ;
1157- assert (liveModules.count (module .ptr ) == 1 &&
1158- " destroying module not in live map" );
1159- liveModules.erase (module .ptr );
1160- mlirModuleDestroy (module );
1161- }
1076+ PyModule::~PyModule () { mlirModuleDestroy (module ); }
11621077
11631078PyModuleRef PyModule::forModule (MlirModule module ) {
11641079 MlirContext context = mlirModuleGetContext (module );
11651080 PyMlirContextRef contextRef = PyMlirContext::forContext (context);
11661081
1167- nb::gil_scoped_acquire acquire;
1168- auto &liveModules = contextRef->liveModules ;
1169- auto it = liveModules.find (module .ptr );
1170- if (it == liveModules.end ()) {
1171- // Create.
1172- PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1173- // Note that the default return value policy on cast is automatic_reference,
1174- // which does not take ownership (delete will not be called).
1175- // Just be explicit.
1176- nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1177- unownedModule->handle = pyRef;
1178- liveModules[module .ptr ] =
1179- std::make_pair (unownedModule->handle , unownedModule);
1180- return PyModuleRef (unownedModule, std::move (pyRef));
1181- }
1182- // Use existing.
1183- PyModule *existing = it->second .second ;
1184- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1185- return PyModuleRef (existing, std::move (pyRef));
1082+ // Create.
1083+ PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1084+ // Note that the default return value policy on cast is automatic_reference,
1085+ // which does not take ownership (delete will not be called).
1086+ // Just be explicit.
1087+ nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1088+ unownedModule->handle = pyRef;
1089+ return PyModuleRef (unownedModule, std::move (pyRef));
11861090}
11871091
11881092nb::object PyModule::createFromCapsule (nb::object capsule) {
@@ -1207,16 +1111,8 @@ PyOperation::~PyOperation() {
12071111 // If the operation has already been invalidated there is nothing to do.
12081112 if (!valid)
12091113 return ;
1210-
1211- // Otherwise, invalidate the operation and remove it from live map when it is
1212- // attached.
1213- if (isAttached ()) {
1214- getContext ()->clearOperation (*this );
1215- } else {
1216- // And destroy it when it is detached, i.e. owned by Python, in which case
1217- // all nested operations must be invalidated at removed from the live map as
1218- // well.
1219- erase ();
1114+ if (!isAttached ()) {
1115+ mlirOperationDestroy (operation);
12201116 }
12211117}
12221118
@@ -1246,41 +1142,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12461142 if (parentKeepAlive) {
12471143 unownedOperation->parentKeepAlive = std::move (parentKeepAlive);
12481144 }
1249- return unownedOperation;
1145+ return PyOperationRef ( unownedOperation, std::move (pyRef)) ;
12501146}
12511147
12521148PyOperationRef PyOperation::forOperation (PyMlirContextRef contextRef,
12531149 MlirOperation operation,
12541150 nb::object parentKeepAlive) {
1255- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1256- auto &liveOperations = contextRef->liveOperations ;
1257- auto it = liveOperations.find (operation.ptr );
1258- if (it == liveOperations.end ()) {
1259- // Create.
1260- PyOperationRef result = createInstance (std::move (contextRef), operation,
1261- std::move (parentKeepAlive));
1262- liveOperations[operation.ptr ] =
1263- std::make_pair (result.getObject (), result.get ());
1264- return result;
1265- }
1266- // Use existing.
1267- PyOperation *existing = it->second .second ;
1268- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1269- return PyOperationRef (existing, std::move (pyRef));
1151+ // Create.
1152+ return createInstance (std::move (contextRef), operation,
1153+ std::move (parentKeepAlive));
12701154}
12711155
12721156PyOperationRef PyOperation::createDetached (PyMlirContextRef contextRef,
12731157 MlirOperation operation,
12741158 nb::object parentKeepAlive) {
1275- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1276- auto &liveOperations = contextRef->liveOperations ;
1277- assert (liveOperations.count (operation.ptr ) == 0 &&
1278- " cannot create detached operation that already exists" );
1279- (void )liveOperations;
12801159 PyOperationRef created = createInstance (std::move (contextRef), operation,
12811160 std::move (parentKeepAlive));
1282- liveOperations[operation.ptr ] =
1283- std::make_pair (created.getObject (), created.get ());
12841161 created->attached = false ;
12851162 return created;
12861163}
@@ -1652,7 +1529,6 @@ nb::object PyOperation::createOpView() {
16521529
16531530void PyOperation::erase () {
16541531 checkValid ();
1655- getContext ()->clearOperationAndInside (*this );
16561532 mlirOperationDestroy (operation);
16571533}
16581534
@@ -2494,7 +2370,6 @@ class PyBlockArgumentList
24942370 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
24952371public:
24962372 static constexpr const char *pyClassName = " BlockArgumentList" ;
2497- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
24982373
24992374 PyBlockArgumentList (PyOperationRef operation, MlirBlock block,
25002375 intptr_t startIndex = 0 , intptr_t length = -1 ,
@@ -3023,14 +2898,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
30232898 PyMlirContextRef ref = PyMlirContext::forContext (self.get ());
30242899 return ref.releaseObject ();
30252900 })
3026- .def (" _get_live_operation_count" , &PyMlirContext::getLiveOperationCount)
3027- .def (" _get_live_operation_objects" ,
3028- &PyMlirContext::getLiveOperationObjects)
3029- .def (" _clear_live_operations" , &PyMlirContext::clearLiveOperations)
3030- .def (" _clear_live_operations_inside" ,
3031- nb::overload_cast<MlirOperation>(
3032- &PyMlirContext::clearOperationsInside))
3033- .def (" _get_live_module_count" , &PyMlirContext::getLiveModuleCount)
30342901 .def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
30352902 .def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
30362903 .def (" __enter__" , &PyMlirContext::contextEnter)
@@ -3428,7 +3295,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34283295 // Defer to the operation's __str__.
34293296 return self.attr (" operation" ).attr (" __str__" )();
34303297 },
3431- kOperationStrDunderDocstring );
3298+ kOperationStrDunderDocstring )
3299+ .def (
3300+ " __eq__" ,
3301+ [](PyModule &self, PyModule &other) {
3302+ return mlirModuleEqual (self.get (), other.get ());
3303+ },
3304+ " other" _a);
34323305
34333306 // ----------------------------------------------------------------------------
34343307 // Mapping of Operation.
@@ -3440,7 +3313,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34403313 })
34413314 .def (" __eq__" ,
34423315 [](PyOperationBase &self, PyOperationBase &other) {
3443- return &self.getOperation () == &other.getOperation ();
3316+ return mlirOperationEqual (self.getOperation ().get (),
3317+ other.getOperation ().get ());
34443318 })
34453319 .def (" __eq__" ,
34463320 [](PyOperationBase &self, nb::object other) { return false ; })
0 commit comments