Skip to content

Commit d0c4d2f

Browse files
committed
refactor: Assign in if statements everywhere, adjust test expectations
Credit goes to @makslevental, his patch can be found in https://gist.github.com/makslevental/b224ffca7f15e273a4897975cda28b4c.
1 parent d7beef9 commit d0c4d2f

File tree

3 files changed

+33
-51
lines changed

3 files changed

+33
-51
lines changed

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 26 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ mlirApiObjectToCapsule(nanobind::handle apiObject) {
5151
nanobind::object api =
5252
nanobind::getattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR, nanobind::none());
5353
if (api.is_none())
54-
return std::nullopt;
54+
return {};
5555
return api;
5656
}
5757

@@ -65,10 +65,8 @@ template <>
6565
struct type_caster<MlirAffineMap> {
6666
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
6767
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
68-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
69-
if (!capsule)
70-
return false;
71-
value = mlirPythonCapsuleToAffineMap(capsule.value().ptr());
68+
if (auto capsule = mlirApiObjectToCapsule(src))
69+
value = mlirPythonCapsuleToAffineMap(capsule->ptr());
7270
return !mlirAffineMapIsNull(value);
7371
}
7472
static handle from_cpp(MlirAffineMap v, rv_policy,
@@ -87,10 +85,8 @@ template <>
8785
struct type_caster<MlirAttribute> {
8886
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
8987
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
90-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
91-
if (!capsule)
92-
return false;
93-
value = mlirPythonCapsuleToAttribute(capsule.value().ptr());
88+
if (auto capsule = mlirApiObjectToCapsule(src))
89+
value = mlirPythonCapsuleToAttribute(capsule->ptr());
9490
return !mlirAttributeIsNull(value);
9591
}
9692
static handle from_cpp(MlirAttribute v, rv_policy,
@@ -110,10 +106,8 @@ template <>
110106
struct type_caster<MlirBlock> {
111107
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
112108
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
113-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
114-
if (!capsule)
115-
return false;
116-
value = mlirPythonCapsuleToBlock(capsule.value().ptr());
109+
if (auto capsule = mlirApiObjectToCapsule(src))
110+
value = mlirPythonCapsuleToBlock(capsule->ptr());
117111
return !mlirBlockIsNull(value);
118112
}
119113
};
@@ -133,7 +127,7 @@ struct type_caster<MlirContext> {
133127
.attr("current");
134128
}
135129
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
136-
value = mlirPythonCapsuleToContext(capsule.value().ptr());
130+
value = mlirPythonCapsuleToContext(capsule->ptr());
137131
return !mlirContextIsNull(value);
138132
}
139133
};
@@ -143,10 +137,8 @@ template <>
143137
struct type_caster<MlirDialectRegistry> {
144138
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
145139
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
146-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
147-
if (!capsule)
148-
return false;
149-
value = mlirPythonCapsuleToDialectRegistry(capsule.value().ptr());
140+
if (auto capsule = mlirApiObjectToCapsule(src))
141+
value = mlirPythonCapsuleToDialectRegistry(capsule->ptr());
150142
return !mlirDialectRegistryIsNull(value);
151143
}
152144
static handle from_cpp(MlirDialectRegistry v, rv_policy,
@@ -171,8 +163,8 @@ struct type_caster<MlirLocation> {
171163
.attr("Location")
172164
.attr("current");
173165
}
174-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
175-
value = mlirPythonCapsuleToLocation(capsule.value().ptr());
166+
if (auto capsule = mlirApiObjectToCapsule(src))
167+
value = mlirPythonCapsuleToLocation(capsule->ptr());
176168
return !mlirLocationIsNull(value);
177169
}
178170
static handle from_cpp(MlirLocation v, rv_policy,
@@ -191,10 +183,8 @@ template <>
191183
struct type_caster<MlirModule> {
192184
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
193185
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
194-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
195-
if (!capsule)
196-
return false;
197-
value = mlirPythonCapsuleToModule(capsule.value().ptr());
186+
if (auto capsule = mlirApiObjectToCapsule(src))
187+
value = mlirPythonCapsuleToModule(capsule->ptr());
198188
return !mlirModuleIsNull(value);
199189
}
200190
static handle from_cpp(MlirModule v, rv_policy,
@@ -214,10 +204,8 @@ struct type_caster<MlirFrozenRewritePatternSet> {
214204
NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
215205
const_name("MlirFrozenRewritePatternSet"))
216206
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
217-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
218-
if (!capsule)
219-
return false;
220-
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.value().ptr());
207+
if (auto capsule = mlirApiObjectToCapsule(src))
208+
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr());
221209
return value.ptr != nullptr;
222210
}
223211
static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy,
@@ -236,10 +224,8 @@ template <>
236224
struct type_caster<MlirOperation> {
237225
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
238226
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
239-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
240-
if (!capsule)
241-
return false;
242-
value = mlirPythonCapsuleToOperation(capsule.value().ptr());
227+
if (auto capsule = mlirApiObjectToCapsule(src))
228+
value = mlirPythonCapsuleToOperation(capsule->ptr());
243229
return !mlirOperationIsNull(value);
244230
}
245231
static handle from_cpp(MlirOperation v, rv_policy,
@@ -260,10 +246,8 @@ template <>
260246
struct type_caster<MlirValue> {
261247
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
262248
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
263-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
264-
if (!capsule)
265-
return false;
266-
value = mlirPythonCapsuleToValue(capsule.value().ptr());
249+
if (auto capsule = mlirApiObjectToCapsule(src))
250+
value = mlirPythonCapsuleToValue(capsule->ptr());
267251
return !mlirValueIsNull(value);
268252
}
269253
static handle from_cpp(MlirValue v, rv_policy,
@@ -285,10 +269,8 @@ template <>
285269
struct type_caster<MlirPassManager> {
286270
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
287271
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
288-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
289-
if (!capsule)
290-
return false;
291-
value = mlirPythonCapsuleToPassManager(capsule.value().ptr());
272+
if (auto capsule = mlirApiObjectToCapsule(src))
273+
value = mlirPythonCapsuleToPassManager(capsule->ptr());
292274
return !mlirPassManagerIsNull(value);
293275
}
294276
};
@@ -298,10 +280,8 @@ template <>
298280
struct type_caster<MlirTypeID> {
299281
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
300282
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
301-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
302-
if (!capsule)
303-
return false;
304-
value = mlirPythonCapsuleToTypeID(capsule.value().ptr());
283+
if (auto capsule = mlirApiObjectToCapsule(src))
284+
value = mlirPythonCapsuleToTypeID(capsule->ptr());
305285
return !mlirTypeIDIsNull(value);
306286
}
307287
static handle from_cpp(MlirTypeID v, rv_policy,
@@ -322,10 +302,8 @@ template <>
322302
struct type_caster<MlirType> {
323303
NB_TYPE_CASTER(MlirType, const_name("MlirType"))
324304
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
325-
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
326-
if (!capsule)
327-
return false;
328-
value = mlirPythonCapsuleToType(capsule.value().ptr());
305+
if (auto capsule = mlirApiObjectToCapsule(src))
306+
value = mlirPythonCapsuleToType(capsule->ptr());
329307
return !mlirTypeIsNull(value);
330308
}
331309
static handle from_cpp(MlirType t, rv_policy,

mlir/test/python/dialects/python_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,9 @@ def testCustomAttribute():
361361
try:
362362
TestAttr(42)
363363
except TypeError as e:
364-
assert "Expected an MLIR object" in str(e)
364+
assert "Expected an MLIR object (got 42)" in str(e)
365+
except ValueError as e:
366+
assert "Cannot cast attribute to TestAttr (from 42)" in str(e)
365367
else:
366368
raise
367369

@@ -406,7 +408,9 @@ def testCustomType():
406408
try:
407409
TestType(42)
408410
except TypeError as e:
409-
assert "Expected an MLIR object" in str(e)
411+
assert "Expected an MLIR object (got 42)" in str(e)
412+
except ValueError as e:
413+
assert "Cannot cast type to TestType (from 42)" in str(e)
410414
else:
411415
raise
412416

mlir/test/python/lib/PythonTestModuleNanobind.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
115115
nanobind::cpp_function([valueCls](const nb::object &valueObj) {
116116
std::optional<nb::object> capsule =
117117
mlirApiObjectToCapsule(valueObj);
118-
// TODO(nicholasjng): Can this capsule be std::nullopt?
118+
assert(capsule.has_value() && "capsule is not null");
119119
MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
120120
MlirType t = mlirValueGetType(v);
121121
// This is hyper-specific in order to exercise/test registering a

0 commit comments

Comments
 (0)