Skip to content

Commit 0d08ffd

Browse files
authored
[MLIR][Python] use nb::typed for return signatures (#160221)
#160183 removed `nb::typed` annotation to fix bazel but it turned out to be simply a matter of not using the correct version of nanobind (see #160183 (comment)). This PR restores those annotations but (mostly) moves to the return positions of the actual methods.
1 parent 4feb092 commit 0d08ffd

File tree

7 files changed

+214
-149
lines changed

7 files changed

+214
-149
lines changed

mlir/lib/Bindings/Python/IRAffine.cpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
574574
})
575575
.def_prop_ro(
576576
"context",
577-
[](PyAffineExpr &self) { return self.getContext().getObject(); })
577+
[](PyAffineExpr &self) -> nb::typed<nb::object, PyMlirContext> {
578+
return self.getContext().getObject();
579+
})
578580
.def("compose",
579581
[](PyAffineExpr &self, PyAffineMap &other) {
580582
return PyAffineExpr(self.getContext(),
@@ -706,28 +708,29 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
706708
[](PyAffineMap &self) {
707709
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
708710
})
709-
.def_static("compress_unused_symbols",
710-
[](const nb::list &affineMaps,
711-
DefaultingPyMlirContext context) {
712-
SmallVector<MlirAffineMap> maps;
713-
pyListToVector<PyAffineMap, MlirAffineMap>(
714-
affineMaps, maps, "attempting to create an AffineMap");
715-
std::vector<MlirAffineMap> compressed(affineMaps.size());
716-
auto populate = [](void *result, intptr_t idx,
717-
MlirAffineMap m) {
718-
static_cast<MlirAffineMap *>(result)[idx] = (m);
719-
};
720-
mlirAffineMapCompressUnusedSymbols(
721-
maps.data(), maps.size(), compressed.data(), populate);
722-
std::vector<PyAffineMap> res;
723-
res.reserve(compressed.size());
724-
for (auto m : compressed)
725-
res.emplace_back(context->getRef(), m);
726-
return res;
727-
})
711+
.def_static(
712+
"compress_unused_symbols",
713+
[](const nb::list &affineMaps, DefaultingPyMlirContext context) {
714+
SmallVector<MlirAffineMap> maps;
715+
pyListToVector<PyAffineMap, MlirAffineMap>(
716+
affineMaps, maps, "attempting to create an AffineMap");
717+
std::vector<MlirAffineMap> compressed(affineMaps.size());
718+
auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
719+
static_cast<MlirAffineMap *>(result)[idx] = (m);
720+
};
721+
mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(),
722+
compressed.data(), populate);
723+
std::vector<PyAffineMap> res;
724+
res.reserve(compressed.size());
725+
for (auto m : compressed)
726+
res.emplace_back(context->getRef(), m);
727+
return res;
728+
})
728729
.def_prop_ro(
729730
"context",
730-
[](PyAffineMap &self) { return self.getContext().getObject(); },
731+
[](PyAffineMap &self) -> nb::typed<nb::object, PyMlirContext> {
732+
return self.getContext().getObject();
733+
},
731734
"Context that owns the Affine Map")
732735
.def(
733736
"dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
@@ -893,7 +896,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
893896
})
894897
.def_prop_ro(
895898
"context",
896-
[](PyIntegerSet &self) { return self.getContext().getObject(); })
899+
[](PyIntegerSet &self) -> nb::typed<nb::object, PyMlirContext> {
900+
return self.getContext().getObject();
901+
})
897902
.def(
898903
"dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
899904
kDumpDocstring)

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
485485

486486
PyArrayAttributeIterator &dunderIter() { return *this; }
487487

488-
nb::object dunderNext() {
488+
nb::typed<nb::object, PyAttribute> dunderNext() {
489489
// TODO: Throw is an inefficient way to stop iteration.
490490
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
491491
throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
526526
"Gets a uniqued Array attribute");
527527
c.def(
528528
"__getitem__",
529-
[](PyArrayAttribute &arr, intptr_t i) {
529+
[](PyArrayAttribute &arr,
530+
intptr_t i) -> nb::typed<nb::object, PyAttribute> {
530531
if (i >= mlirArrayAttrGetNumElements(arr))
531532
throw nb::index_error("ArrayAttribute index out of range");
532533
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
10101011
[](PyDenseElementsAttribute &self) -> bool {
10111012
return mlirDenseElementsAttrIsSplat(self);
10121013
})
1013-
.def("get_splat_value", [](PyDenseElementsAttribute &self) {
1014-
if (!mlirDenseElementsAttrIsSplat(self))
1015-
throw nb::value_error(
1016-
"get_splat_value called on a non-splat attribute");
1017-
return PyAttribute(self.getContext(),
1018-
mlirDenseElementsAttrGetSplatValue(self))
1019-
.maybeDownCast();
1020-
});
1014+
.def("get_splat_value",
1015+
[](PyDenseElementsAttribute &self)
1016+
-> nb::typed<nb::object, PyAttribute> {
1017+
if (!mlirDenseElementsAttrIsSplat(self))
1018+
throw nb::value_error(
1019+
"get_splat_value called on a non-splat attribute");
1020+
return PyAttribute(self.getContext(),
1021+
mlirDenseElementsAttrGetSplatValue(self))
1022+
.maybeDownCast();
1023+
});
10211024
}
10221025

10231026
static PyType_Slot slots[];
@@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute
13321335

13331336
/// Returns the element at the given linear position. Asserts if the index
13341337
/// is out of range.
1335-
nb::object dunderGetItem(intptr_t pos) {
1338+
nb::int_ dunderGetItem(intptr_t pos) {
13361339
if (pos < 0 || pos >= dunderLen()) {
13371340
throw nb::index_error("attempt to access out of bounds element");
13381341
}
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
15221525
},
15231526
nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
15241527
"Gets an uniqued dict attribute");
1525-
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1526-
MlirAttribute attr =
1527-
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1528-
if (mlirAttributeIsNull(attr))
1529-
throw nb::key_error("attempt to access a non-existent attribute");
1530-
return PyAttribute(self.getContext(), attr).maybeDownCast();
1531-
});
1528+
c.def("__getitem__",
1529+
[](PyDictAttribute &self,
1530+
const std::string &name) -> nb::typed<nb::object, PyAttribute> {
1531+
MlirAttribute attr =
1532+
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1533+
if (mlirAttributeIsNull(attr))
1534+
throw nb::key_error("attempt to access a non-existent attribute");
1535+
return PyAttribute(self.getContext(), attr).maybeDownCast();
1536+
});
15321537
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
15331538
if (index < 0 || index >= self.dunderLen()) {
15341539
throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
15941599
},
15951600
nb::arg("value"), nb::arg("context") = nb::none(),
15961601
"Gets a uniqued Type attribute");
1597-
c.def_prop_ro("value", [](PyTypeAttribute &self) {
1598-
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
1599-
.maybeDownCast();
1600-
});
1602+
c.def_prop_ro(
1603+
"value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
1604+
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
1605+
.maybeDownCast();
1606+
});
16011607
}
16021608
};
16031609

0 commit comments

Comments
 (0)