Skip to content

Commit d7beef9

Browse files
committed
refactor: Change mlirApiObjectToCapsule to return std::optional
Instead of raising a `nanobind::type_error()`. This is necessary to honor the nanobind type caster API contract, which requires `from_python` and `from_cpp` methods to be marked `noexcept`.
1 parent 35d5b50 commit d7beef9

File tree

2 files changed

+56
-38
lines changed

2 files changed

+56
-38
lines changed

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
2121

2222
#include <cstdint>
23+
#include <optional>
2324

2425
#include "mlir-c/Diagnostics.h"
2526
#include "mlir-c/IR.h"
@@ -43,18 +44,14 @@ namespace detail {
4344
/// with a raw handle (unowned). The returned object's lifetime may not extend
4445
/// beyond the apiObject handle without explicitly having its refcount increased
4546
/// (i.e. on return).
46-
static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) {
47+
static std::optional<nanobind::object>
48+
mlirApiObjectToCapsule(nanobind::handle apiObject) {
4749
if (PyCapsule_CheckExact(apiObject.ptr()))
4850
return nanobind::borrow<nanobind::object>(apiObject);
4951
nanobind::object api =
5052
nanobind::getattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR, nanobind::none());
51-
if (api.is_none()) {
52-
std::string repr = nanobind::cast<std::string>(nanobind::repr(apiObject));
53-
throw nanobind::type_error(
54-
(llvm::Twine("Expected an MLIR object (got ") + repr + ").")
55-
.str()
56-
.c_str());
57-
}
53+
if (api.is_none())
54+
return std::nullopt;
5855
return api;
5956
}
6057

@@ -68,11 +65,10 @@ template <>
6865
struct type_caster<MlirAffineMap> {
6966
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
7067
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
71-
nanobind::object capsule = mlirApiObjectToCapsule(src);
72-
value = mlirPythonCapsuleToAffineMap(capsule.ptr());
73-
if (mlirAffineMapIsNull(value)) {
68+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
69+
if (!capsule)
7470
return false;
75-
}
71+
value = mlirPythonCapsuleToAffineMap(capsule.value().ptr());
7672
return !mlirAffineMapIsNull(value);
7773
}
7874
static handle from_cpp(MlirAffineMap v, rv_policy,
@@ -91,8 +87,10 @@ template <>
9187
struct type_caster<MlirAttribute> {
9288
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
9389
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
94-
nanobind::object capsule = mlirApiObjectToCapsule(src);
95-
value = mlirPythonCapsuleToAttribute(capsule.ptr());
90+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
91+
if (!capsule)
92+
return false;
93+
value = mlirPythonCapsuleToAttribute(capsule.value().ptr());
9694
return !mlirAttributeIsNull(value);
9795
}
9896
static handle from_cpp(MlirAttribute v, rv_policy,
@@ -112,8 +110,10 @@ template <>
112110
struct type_caster<MlirBlock> {
113111
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
114112
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
115-
nanobind::object capsule = mlirApiObjectToCapsule(src);
116-
value = mlirPythonCapsuleToBlock(capsule.ptr());
113+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
114+
if (!capsule)
115+
return false;
116+
value = mlirPythonCapsuleToBlock(capsule.value().ptr());
117117
return !mlirBlockIsNull(value);
118118
}
119119
};
@@ -132,8 +132,8 @@ struct type_caster<MlirContext> {
132132
.attr("Context")
133133
.attr("current");
134134
}
135-
nanobind::object capsule = mlirApiObjectToCapsule(src);
136-
value = mlirPythonCapsuleToContext(capsule.ptr());
135+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
136+
value = mlirPythonCapsuleToContext(capsule.value().ptr());
137137
return !mlirContextIsNull(value);
138138
}
139139
};
@@ -143,8 +143,10 @@ template <>
143143
struct type_caster<MlirDialectRegistry> {
144144
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
145145
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
146-
nanobind::object capsule = mlirApiObjectToCapsule(src);
147-
value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
146+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
147+
if (!capsule)
148+
return false;
149+
value = mlirPythonCapsuleToDialectRegistry(capsule.value().ptr());
148150
return !mlirDialectRegistryIsNull(value);
149151
}
150152
static handle from_cpp(MlirDialectRegistry v, rv_policy,
@@ -169,8 +171,8 @@ struct type_caster<MlirLocation> {
169171
.attr("Location")
170172
.attr("current");
171173
}
172-
nanobind::object capsule = mlirApiObjectToCapsule(src);
173-
value = mlirPythonCapsuleToLocation(capsule.ptr());
174+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
175+
value = mlirPythonCapsuleToLocation(capsule.value().ptr());
174176
return !mlirLocationIsNull(value);
175177
}
176178
static handle from_cpp(MlirLocation v, rv_policy,
@@ -189,8 +191,10 @@ template <>
189191
struct type_caster<MlirModule> {
190192
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
191193
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
192-
nanobind::object capsule = mlirApiObjectToCapsule(src);
193-
value = mlirPythonCapsuleToModule(capsule.ptr());
194+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
195+
if (!capsule)
196+
return false;
197+
value = mlirPythonCapsuleToModule(capsule.value().ptr());
194198
return !mlirModuleIsNull(value);
195199
}
196200
static handle from_cpp(MlirModule v, rv_policy,
@@ -210,8 +214,10 @@ struct type_caster<MlirFrozenRewritePatternSet> {
210214
NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
211215
const_name("MlirFrozenRewritePatternSet"))
212216
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
213-
nanobind::object capsule = mlirApiObjectToCapsule(src);
214-
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
217+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
218+
if (!capsule)
219+
return false;
220+
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.value().ptr());
215221
return value.ptr != nullptr;
216222
}
217223
static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy,
@@ -230,8 +236,10 @@ template <>
230236
struct type_caster<MlirOperation> {
231237
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
232238
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
233-
nanobind::object capsule = mlirApiObjectToCapsule(src);
234-
value = mlirPythonCapsuleToOperation(capsule.ptr());
239+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
240+
if (!capsule)
241+
return false;
242+
value = mlirPythonCapsuleToOperation(capsule.value().ptr());
235243
return !mlirOperationIsNull(value);
236244
}
237245
static handle from_cpp(MlirOperation v, rv_policy,
@@ -252,8 +260,10 @@ template <>
252260
struct type_caster<MlirValue> {
253261
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
254262
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
255-
nanobind::object capsule = mlirApiObjectToCapsule(src);
256-
value = mlirPythonCapsuleToValue(capsule.ptr());
263+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
264+
if (!capsule)
265+
return false;
266+
value = mlirPythonCapsuleToValue(capsule.value().ptr());
257267
return !mlirValueIsNull(value);
258268
}
259269
static handle from_cpp(MlirValue v, rv_policy,
@@ -275,8 +285,10 @@ template <>
275285
struct type_caster<MlirPassManager> {
276286
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
277287
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
278-
nanobind::object capsule = mlirApiObjectToCapsule(src);
279-
value = mlirPythonCapsuleToPassManager(capsule.ptr());
288+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
289+
if (!capsule)
290+
return false;
291+
value = mlirPythonCapsuleToPassManager(capsule.value().ptr());
280292
return !mlirPassManagerIsNull(value);
281293
}
282294
};
@@ -286,8 +298,10 @@ template <>
286298
struct type_caster<MlirTypeID> {
287299
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
288300
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
289-
nanobind::object capsule = mlirApiObjectToCapsule(src);
290-
value = mlirPythonCapsuleToTypeID(capsule.ptr());
301+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
302+
if (!capsule)
303+
return false;
304+
value = mlirPythonCapsuleToTypeID(capsule.value().ptr());
291305
return !mlirTypeIDIsNull(value);
292306
}
293307
static handle from_cpp(MlirTypeID v, rv_policy,
@@ -308,8 +322,10 @@ template <>
308322
struct type_caster<MlirType> {
309323
NB_TYPE_CASTER(MlirType, const_name("MlirType"))
310324
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
311-
nanobind::object capsule = mlirApiObjectToCapsule(src);
312-
value = mlirPythonCapsuleToType(capsule.ptr());
325+
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
326+
if (!capsule)
327+
return false;
328+
value = mlirPythonCapsuleToType(capsule.value().ptr());
313329
return !mlirTypeIsNull(value);
314330
}
315331
static handle from_cpp(MlirType t, rv_policy,

mlir/test/python/lib/PythonTestModuleNanobind.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,10 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
113113
.attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
114114
mlirRankedTensorTypeID)(
115115
nanobind::cpp_function([valueCls](const nb::object &valueObj) {
116-
nb::object capsule = mlirApiObjectToCapsule(valueObj);
117-
MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
116+
std::optional<nb::object> capsule =
117+
mlirApiObjectToCapsule(valueObj);
118+
// TODO(nicholasjng): Can this capsule be std::nullopt?
119+
MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
118120
MlirType t = mlirValueGetType(v);
119121
// This is hyper-specific in order to exercise/test registering a
120122
// value caster from cpp (but only for a single test case; see

0 commit comments

Comments
 (0)