From d3585d65daba487de2626071c12a7a9d79d6621f Mon Sep 17 00:00:00 2001 From: makslevental Date: Mon, 22 Sep 2025 18:57:51 -0700 Subject: [PATCH] [MLIR][Python] use nb::typed for return signatures --- mlir/lib/Bindings/Python/IRAffine.cpp | 49 +++--- mlir/lib/Bindings/Python/IRAttributes.cpp | 50 +++--- mlir/lib/Bindings/Python/IRCore.cpp | 200 ++++++++++++++-------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 26 ++- mlir/lib/Bindings/Python/IRModule.h | 14 +- mlir/lib/Bindings/Python/IRTypes.cpp | 21 +-- mlir/test/python/dialects/python_test.py | 3 + 7 files changed, 214 insertions(+), 149 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index bc6aa0dac6221..7147f2cbad149 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) { }) .def_prop_ro( "context", - [](PyAffineExpr &self) { return self.getContext().getObject(); }) + [](PyAffineExpr &self) -> nb::typed { + return self.getContext().getObject(); + }) .def("compose", [](PyAffineExpr &self, PyAffineMap &other) { return PyAffineExpr(self.getContext(), @@ -706,28 +708,29 @@ void mlir::python::populateIRAffine(nb::module_ &m) { [](PyAffineMap &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_static("compress_unused_symbols", - [](const nb::list &affineMaps, - DefaultingPyMlirContext context) { - SmallVector maps; - pyListToVector( - affineMaps, maps, "attempting to create an AffineMap"); - std::vector compressed(affineMaps.size()); - auto populate = [](void *result, intptr_t idx, - MlirAffineMap m) { - static_cast(result)[idx] = (m); - }; - mlirAffineMapCompressUnusedSymbols( - maps.data(), maps.size(), compressed.data(), populate); - std::vector res; - res.reserve(compressed.size()); - for (auto m : compressed) - res.emplace_back(context->getRef(), m); - return res; - }) + .def_static( + "compress_unused_symbols", + [](const nb::list &affineMaps, DefaultingPyMlirContext context) { + SmallVector maps; + pyListToVector( + affineMaps, maps, "attempting to create an AffineMap"); + std::vector compressed(affineMaps.size()); + auto populate = [](void *result, intptr_t idx, MlirAffineMap m) { + static_cast(result)[idx] = (m); + }; + mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(), + compressed.data(), populate); + std::vector res; + res.reserve(compressed.size()); + for (auto m : compressed) + res.emplace_back(context->getRef(), m); + return res; + }) .def_prop_ro( "context", - [](PyAffineMap &self) { return self.getContext().getObject(); }, + [](PyAffineMap &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Affine Map") .def( "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, @@ -893,7 +896,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) { }) .def_prop_ro( "context", - [](PyIntegerSet &self) { return self.getContext().getObject(); }) + [](PyIntegerSet &self) -> nb::typed { + return self.getContext().getObject(); + }) .def( "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, kDumpDocstring) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 212228fbac91e..c77653f97e6dd 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute { PyArrayAttributeIterator &dunderIter() { return *this; } - nb::object dunderNext() { + nb::typed dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw nb::stop_iteration(); @@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute { "Gets a uniqued Array attribute"); c.def( "__getitem__", - [](PyArrayAttribute &arr, intptr_t i) { + [](PyArrayAttribute &arr, + intptr_t i) -> nb::typed { if (i >= mlirArrayAttrGetNumElements(arr)) throw nb::index_error("ArrayAttribute index out of range"); return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast(); @@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute [](PyDenseElementsAttribute &self) -> bool { return mlirDenseElementsAttrIsSplat(self); }) - .def("get_splat_value", [](PyDenseElementsAttribute &self) { - if (!mlirDenseElementsAttrIsSplat(self)) - throw nb::value_error( - "get_splat_value called on a non-splat attribute"); - return PyAttribute(self.getContext(), - mlirDenseElementsAttrGetSplatValue(self)) - .maybeDownCast(); - }); + .def("get_splat_value", + [](PyDenseElementsAttribute &self) + -> nb::typed { + if (!mlirDenseElementsAttrIsSplat(self)) + throw nb::value_error( + "get_splat_value called on a non-splat attribute"); + return PyAttribute(self.getContext(), + mlirDenseElementsAttrGetSplatValue(self)) + .maybeDownCast(); + }); } static PyType_Slot slots[]; @@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute /// Returns the element at the given linear position. Asserts if the index /// is out of range. - nb::object dunderGetItem(intptr_t pos) { + nb::int_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { throw nb::index_error("attempt to access out of bounds element"); } @@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute { }, nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(), "Gets an uniqued dict attribute"); - c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { - MlirAttribute attr = - mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) - throw nb::key_error("attempt to access a non-existent attribute"); - return PyAttribute(self.getContext(), attr).maybeDownCast(); - }); + c.def("__getitem__", + [](PyDictAttribute &self, + const std::string &name) -> nb::typed { + MlirAttribute attr = + mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) + throw nb::key_error("attempt to access a non-existent attribute"); + return PyAttribute(self.getContext(), attr).maybeDownCast(); + }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { throw nb::index_error("attempt to access out of bounds attribute"); @@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute { }, nb::arg("value"), nb::arg("context") = nb::none(), "Gets a uniqued Type attribute"); - c.def_prop_ro("value", [](PyTypeAttribute &self) { - return PyType(self.getContext(), mlirTypeAttrGetValue(self.get())) - .maybeDownCast(); - }); + c.def_prop_ro( + "value", [](PyTypeAttribute &self) -> nb::typed { + return PyType(self.getContext(), mlirTypeAttrGetValue(self.get())) + .maybeDownCast(); + }); } }; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 4b238e11c7fff..83a8757bb72c7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -513,7 +513,7 @@ class PyOperationIterator { PyOperationIterator &dunderIter() { return *this; } - nb::object dunderNext() { + nb::typed dunderNext() { parentOperation->checkValid(); if (mlirOperationIsNull(next)) { throw nb::stop_iteration(); @@ -562,7 +562,7 @@ class PyOperationList { return count; } - nb::object dunderGetItem(intptr_t index) { + nb::typed dunderGetItem(intptr_t index) { parentOperation->checkValid(); if (index < 0) { index += dunderLen(); @@ -725,7 +725,7 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { new PyDiagnosticHandler(get(), std::move(callback)); nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::take_ownership); - pyHandlerObject.inc_ref(); + (void)pyHandlerObject.inc_ref(); // In these C callbacks, the userData is a PyDiagnosticHandler* that is // guaranteed to be known to pybind. @@ -1395,7 +1395,7 @@ nb::object PyOperation::getCapsule() { return nb::steal(mlirPythonOperationToCapsule(get())); } -nb::object PyOperation::createFromCapsule(nb::object capsule) { +nb::object PyOperation::createFromCapsule(const nb::object &capsule) { MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); if (mlirOperationIsNull(rawOperation)) throw nb::python_error(); @@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue { }, nb::arg("other_value")); cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](DerivedTy &self) { return self.maybeDownCast(); }); + [](DerivedTy &self) -> nb::typed { + return self.maybeDownCast(); + }); DerivedTy::bindDerived(cls); } @@ -1623,13 +1625,14 @@ class PyOpResult : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyOpResult &self) { - assert( - mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in the IR"); - return self.getParentOperation().getObject(); - }); + c.def_prop_ro( + "owner", [](PyOpResult &self) -> nb::typed { + assert(mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in " + "the IR"); + return self.getParentOperation().getObject(); + }); c.def_prop_ro("result_number", [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); }); @@ -1638,9 +1641,9 @@ class PyOpResult : public PyConcreteValue { /// Returns the list of types of the values held by container. template -static std::vector getValueTypes(Container &container, - PyMlirContextRef &context) { - std::vector result; +static std::vector> +getValueTypes(Container &container, PyMlirContextRef &context) { + std::vector> result; result.reserve(container.size()); for (int i = 0, e = container.size(); i < e; ++i) { result.push_back(PyType(context->getRef(), @@ -1671,9 +1674,10 @@ class PyOpResultList : public Sliceable { c.def_prop_ro("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); }); - c.def_prop_ro("owner", [](PyOpResultList &self) { - return self.operation->createOpView(); - }); + c.def_prop_ro("owner", + [](PyOpResultList &self) -> nb::typed { + return self.operation->createOpView(); + }); } PyOperationRef &getOperation() { return operation; } @@ -2104,7 +2108,7 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) { size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { - return PyThreadContextEntry::pushInsertionPoint(insertPoint); + return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint)); } void PyInsertionPoint::contextExit(const nb::object &excType, @@ -2125,7 +2129,7 @@ nb::object PyAttribute::getCapsule() { return nb::steal(mlirPythonAttributeToCapsule(*this)); } -PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { +PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) { MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); if (mlirAttributeIsNull(rawAttr)) throw nb::python_error(); @@ -2677,7 +2681,8 @@ class PyOpAttributeMap { PyOpAttributeMap(PyOperationRef operation) : operation(std::move(operation)) {} - nb::object dunderGetItemNamed(const std::string &name) { + nb::typed + dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { @@ -2962,24 +2967,27 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", - [](PyMlirContext &self) { + [](PyMlirContext &self) -> nb::typed { PyMlirContextRef ref = PyMlirContext::forContext(self.get()); return ref.releaseObject(); }) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), nb::arg("traceback").none()) .def_prop_ro_static( "current", - [](nb::object & /*class*/) { + [](nb::object & /*class*/) + -> std::optional> { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - return nb::none(); + return {}; return nb::cast(context); }, + nb::sig("def current(/) -> Context | None"), "Gets the Context bound to the current thread or raises ValueError") .def_prop_ro( "dialects", @@ -3123,7 +3131,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "DialectRegistry") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyDialectRegistry::createFromCapsule) .def(nb::init<>()); //---------------------------------------------------------------------------- @@ -3131,7 +3140,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "Location") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", &PyLocation::contextEnter) .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), nb::arg("traceback").none()) @@ -3286,7 +3295,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Gets a Location from a LocationAttr") .def_prop_ro( "context", - [](PyLocation &self) { return self.getContext().getObject(); }, + [](PyLocation &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Location") .def_prop_ro( "attr", @@ -3313,12 +3324,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "Module", nb::is_weak_referenceable()) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, - kModuleCAPICreate) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, + kModuleCAPICreate) .def("_clear_mlir_module", &PyModule::clearMlirModule) .def_static( "parse", - [](const std::string &moduleAsm, DefaultingPyMlirContext context) { + [](const std::string &moduleAsm, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParse( context->get(), toMlirStringRef(moduleAsm)); @@ -3330,7 +3342,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { kModuleParseDocstring) .def_static( "parse", - [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { + [](nb::bytes moduleAsm, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParse( context->get(), toMlirStringRef(moduleAsm)); @@ -3342,7 +3355,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { kModuleParseDocstring) .def_static( "parseFile", - [](const std::string &path, DefaultingPyMlirContext context) { + [](const std::string &path, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParseFromFile( context->get(), toMlirStringRef(path)); @@ -3354,7 +3368,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { kModuleParseDocstring) .def_static( "create", - [](const std::optional &loc) { + [](const std::optional &loc) + -> nb::typed { PyLocation pyLoc = maybeGetTracebackLocation(loc); MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); @@ -3362,11 +3377,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("loc") = nb::none(), "Creates an empty module") .def_prop_ro( "context", - [](PyModule &self) { return self.getContext().getObject(); }, + [](PyModule &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that created the Module") .def_prop_ro( "operation", - [](PyModule &self) { + [](PyModule &self) -> nb::typed { return PyOperation::forOperation(self.getContext(), mlirModuleGetOperation(self.get()), self.getRef().releaseObject()) @@ -3430,7 +3447,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def_prop_ro( "context", - [](PyOperationBase &self) { + [](PyOperationBase &self) -> nb::typed { PyOperation &concreteOperation = self.getOperation(); concreteOperation.checkValid(); return concreteOperation.getContext().getObject(); @@ -3461,7 +3478,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns the list of Operation results.") .def_prop_ro( "result", - [](PyOperationBase &self) { + [](PyOperationBase &self) -> nb::typed { auto &operation = self.getOperation(); return PyOpResult(operation.getRef(), getUniqueResult(operation)) .maybeDownCast(); @@ -3478,11 +3495,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns the source location the operation was defined or derived " "from.") .def_prop_ro("parent", - [](PyOperationBase &self) -> nb::object { + [](PyOperationBase &self) + -> std::optional> { auto parent = self.getOperation().getParentOperation(); if (parent) return parent->getObject(); - return nb::none(); + return {}; }) .def( "__str__", @@ -3553,13 +3571,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { "of the parent block.") .def( "clone", - [](PyOperationBase &self, nb::object ip) { + [](PyOperationBase &self, + const nb::object &ip) -> nb::typed { return self.getOperation().clone(ip); }, nb::arg("ip") = nb::none()) .def( "detach_from_parent", - [](PyOperationBase &self) { + [](PyOperationBase &self) -> nb::typed { PyOperation &operation = self.getOperation(); operation.checkValid(); if (!operation.isAttached()) @@ -3595,7 +3614,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional attributes, std::optional> successors, int regions, const std::optional &location, - const nb::object &maybeIp, bool inferType) { + const nb::object &maybeIp, + bool inferType) -> nb::typed { // Unpack/validate operands. llvm::SmallVector mlirOperands; if (operands) { @@ -3620,7 +3640,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, - DefaultingPyMlirContext context) { + DefaultingPyMlirContext context) + -> nb::typed { return PyOperation::parse(context->getRef(), sourceStr, sourceName) ->createOpView(); }, @@ -3629,9 +3650,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Parses an operation. Supports both text assembly format and binary " "bytecode format.") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_prop_ro("operation", [](nb::object self) { return self; }) - .def_prop_ro("opview", &PyOperation::createOpView) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyOperation::createFromCapsule) + .def_prop_ro("operation", + [](nb::object self) -> nb::typed { + return self; + }) + .def_prop_ro("opview", + [](PyOperation &self) -> nb::typed { + return self.createOpView(); + }) .def_prop_ro("block", &PyOperation::getBlock) .def_prop_ro( "successors", @@ -3644,7 +3672,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { auto opViewClass = nb::class_(m, "OpView") - .def(nb::init(), nb::arg("operation")) + .def(nb::init>(), + nb::arg("operation")) .def( "__init__", [](PyOpView *self, std::string_view name, @@ -3671,9 +3700,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none()) - - .def_prop_ro("operation", &PyOpView::getOperationObject) - .def_prop_ro("opview", [](nb::object self) { return self; }) + .def_prop_ro( + "operation", + [](PyOpView &self) -> nb::typed { + return self.getOperationObject(); + }) + .def_prop_ro("opview", + [](nb::object self) -> nb::typed { + return self; + }) .def( "__str__", [](PyOpView &self) { return nb::str(self.getOperationObject()); }) @@ -3717,7 +3752,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Builds a specific, generated OpView based on class level attributes."); opViewClass.attr("parse") = classmethod( [](const nb::object &cls, const std::string &sourceStr, - const std::string &sourceName, DefaultingPyMlirContext context) { + const std::string &sourceName, + DefaultingPyMlirContext context) -> nb::typed { PyOperationRef parsed = PyOperation::parse(context->getRef(), sourceStr, sourceName); @@ -3752,7 +3788,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns a forward-optimized sequence of blocks.") .def_prop_ro( "owner", - [](PyRegion &self) { + [](PyRegion &self) -> nb::typed { return self.getParentOperation()->createOpView(); }, "Returns the operation owning this region.") @@ -3777,7 +3813,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) .def_prop_ro( "owner", - [](PyBlock &self) { + [](PyBlock &self) -> nb::typed { return self.getParentOperation()->createOpView(); }, "Returns the owning operation of this block.") @@ -3960,11 +3996,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns the block that this InsertionPoint points to.") .def_prop_ro( "ref_operation", - [](PyInsertionPoint &self) -> nb::object { + [](PyInsertionPoint &self) + -> std::optional> { auto refOperation = self.getRefOperation(); if (refOperation) return refOperation->getObject(); - return nb::none(); + return {}; }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " @@ -3979,10 +4016,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed attribute to the generic Attribute") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyAttribute::createFromCapsule) .def_static( "parse", - [](const std::string &attrSpec, DefaultingPyMlirContext context) { + [](const std::string &attrSpec, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirAttribute attr = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); @@ -3995,10 +4034,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { "failure.") .def_prop_ro( "context", - [](PyAttribute &self) { return self.getContext().getObject(); }, + [](PyAttribute &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Attribute") .def_prop_ro("type", - [](PyAttribute &self) { + [](PyAttribute &self) -> nb::typed { return PyType(self.getContext(), mlirAttributeGetType(self)) .maybeDownCast(); @@ -4049,7 +4090,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { "mlirTypeID was expected to be non-null."); return PyTypeID(mlirTypeID); }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast); + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyAttribute &self) -> nb::typed { + return self.maybeDownCast(); + }); //---------------------------------------------------------------------------- // Mapping of PyNamedAttribute @@ -4091,10 +4135,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed type to the generic Type") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( "parse", - [](std::string typeSpec, DefaultingPyMlirContext context) { + [](std::string typeSpec, + DefaultingPyMlirContext context) -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirType type = mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); @@ -4105,7 +4150,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("asm"), nb::arg("context") = nb::none(), kContextParseTypeDocstring) .def_prop_ro( - "context", [](PyType &self) { return self.getContext().getObject(); }, + "context", + [](PyType &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) .def( @@ -4139,7 +4187,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyType &self) -> nb::typed { + return self.maybeDownCast(); + }) .def_prop_ro("typeid", [](PyType &self) { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) @@ -4154,7 +4205,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "TypeID") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether // the Python objects are the same (i.e., PyTypeID is a value type). @@ -4175,10 +4226,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_(m, "Value") .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) .def_prop_ro( "context", - [](PyValue &self) { + [](PyValue &self) -> nb::typed { return self.getParentOperation()->getContext().getObject(); }, "Context in which the value lives.") @@ -4266,7 +4317,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("state"), kGetNameAsOperand) .def_prop_ro("type", - [](PyValue &self) { + [](PyValue &self) -> nb::typed { return PyType(self.getParentOperation()->getContext(), mlirValueGetType(self.get())) .maybeDownCast(); @@ -4332,7 +4383,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("with_"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) -> nb::typed { + return self.maybeDownCast(); + }) .def_prop_ro( "location", [](MlirValue self) { @@ -4357,7 +4411,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "SymbolTable") .def(nb::init()) - .def("__getitem__", &PySymbolTable::dunderGetItem) + .def("__getitem__", + [](PySymbolTable &self, + const std::string &name) -> nb::typed { + return self.dunderGetItem(name); + }) .def("insert", &PySymbolTable::insert, nb::arg("operation")) .def("erase", &PySymbolTable::erase, nb::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 44aad10ded082..31d4798ffb906 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -212,22 +212,18 @@ class PyConcreteOpInterface { /// Returns the operation instance from which this object was constructed. /// Throws a type error if this object was constructed from a subclass of /// OpView. - nb::object getOperationObject() { - if (operation == nullptr) { + nb::typed getOperationObject() { + if (operation == nullptr) throw nb::type_error("Cannot get an operation from a static interface"); - } - return operation->getRef().releaseObject(); } /// Returns the opview of the operation instance from which this object was /// constructed. Throws a type error if this object was constructed form a /// subclass of OpView. - nb::object getOpView() { - if (operation == nullptr) { + nb::typed getOpView() { + if (operation == nullptr) throw nb::type_error("Cannot get an opview from a static interface"); - } - return operation->createOpView(); } @@ -362,10 +358,9 @@ class PyShapedTypeComponents { "Returns whether the given shaped type component is ranked.") .def_prop_ro( "rank", - [](PyShapedTypeComponents &self) -> nb::object { - if (!self.ranked) { - return nb::none(); - } + [](PyShapedTypeComponents &self) -> std::optional { + if (!self.ranked) + return {}; return nb::int_(self.shape.size()); }, "Returns the rank of the given ranked shaped type components. If " @@ -373,10 +368,9 @@ class PyShapedTypeComponents { "returned.") .def_prop_ro( "shape", - [](PyShapedTypeComponents &self) -> nb::object { - if (!self.ranked) { - return nb::none(); - } + [](PyShapedTypeComponents &self) -> std::optional { + if (!self.ranked) + return {}; return nb::list(self.shape); }, "Returns the shape of the ranked shaped type components as a list " diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 6e97c00d478f1..598ae0188464a 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -671,7 +671,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates a PyOperation from the MlirOperation wrapped by a capsule. /// Ownership of the underlying MlirOperation is taken by calling this /// function. - static nanobind::object createFromCapsule(nanobind::object capsule); + static nanobind::object createFromCapsule(const nanobind::object &capsule); /// Creates an operation. See corresponding python docstring. static nanobind::object @@ -1020,7 +1020,7 @@ class PyAttribute : public BaseContextObject { /// Note that PyAttribute instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAttribute /// is taken by calling this function. - static PyAttribute createFromCapsule(nanobind::object capsule); + static PyAttribute createFromCapsule(const nanobind::object &capsule); nanobind::object maybeDownCast(); @@ -1101,10 +1101,12 @@ class PyConcreteAttribute : public BaseTy { return DerivedTy::isaFunction(otherAttr); }, nanobind::arg("other")); - cls.def_prop_ro("type", [](PyAttribute &attr) { - return PyType(attr.getContext(), mlirAttributeGetType(attr)) - .maybeDownCast(); - }); + cls.def_prop_ro( + "type", + [](PyAttribute &attr) -> nanobind::typed { + return PyType(attr.getContext(), mlirAttributeGetType(attr)) + .maybeDownCast(); + }); cls.def_prop_ro_static( "static_typeid", [](nanobind::object & /*class*/) -> PyTypeID { diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index cab3bf549295b..07dc00521833f 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -501,7 +501,7 @@ class PyComplexType : public PyConcreteType { "Create a complex type"); c.def_prop_ro( "element_type", - [](PyComplexType &self) { + [](PyComplexType &self) -> nb::typed { return PyType(self.getContext(), mlirComplexTypeGetElementType(self)) .maybeDownCast(); }, @@ -515,7 +515,7 @@ class PyComplexType : public PyConcreteType { void mlir::PyShapedType::bindDerived(ClassTy &c) { c.def_prop_ro( "element_type", - [](PyShapedType &self) { + [](PyShapedType &self) -> nb::typed { return PyType(self.getContext(), mlirShapedTypeGetElementType(self)) .maybeDownCast(); }, @@ -731,8 +731,7 @@ class PyRankedTensorType MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); if (mlirAttributeIsNull(encoding)) return std::nullopt; - return nb::cast>( - PyAttribute(self.getContext(), encoding).maybeDownCast()); + return PyAttribute(self.getContext(), encoding).maybeDownCast(); }); } }; @@ -794,9 +793,9 @@ class PyMemRefType : public PyConcreteType { .def_prop_ro( "layout", [](PyMemRefType &self) -> nb::typed { - return nb::cast>( - PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self)) - .maybeDownCast()); + return PyAttribute(self.getContext(), + mlirMemRefTypeGetLayout(self)) + .maybeDownCast(); }, "The layout of the MemRef type.") .def( @@ -825,8 +824,7 @@ class PyMemRefType : public PyConcreteType { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); if (mlirAttributeIsNull(a)) return std::nullopt; - return nb::cast>( - PyAttribute(self.getContext(), a).maybeDownCast()); + return PyAttribute(self.getContext(), a).maybeDownCast(); }, "Returns the memory space of the given MemRef type."); } @@ -867,8 +865,7 @@ class PyUnrankedMemRefType MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); if (mlirAttributeIsNull(a)) return std::nullopt; - return nb::cast>( - PyAttribute(self.getContext(), a).maybeDownCast()); + return PyAttribute(self.getContext(), a).maybeDownCast(); }, "Returns the memory space of the given Unranked MemRef type."); } @@ -912,7 +909,7 @@ class PyTupleType : public PyConcreteType { "Create a tuple type"); c.def( "get_type", - [](PyTupleType &self, intptr_t pos) { + [](PyTupleType &self, intptr_t pos) -> nb::typed { return PyType(self.getContext(), mlirTupleTypeGetType(self, pos)) .maybeDownCast(); }, diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 6ac25e129dacc..761d22357f8f8 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1024,3 +1024,6 @@ def testVariadicAndNormalRegionOp(): is RegionSequence ) assert type(region_op.variadic) is RegionSequence + + assert isinstance(region_op.opview, OpView) + assert isinstance(region_op.operation.opview, OpView)