1+
2+ /*
3+ pybind11/eigen_tensor.h: Transparent conversion for Eigen tensors
4+
5+ Copyright (c) 2016 Wenzel Jakob <[email protected] > 6+
7+ All rights reserved. Use of this source code is governed by a
8+ BSD-style license that can be found in the LICENSE file.
9+ */
10+
11+ #pragma once
12+
13+ /* HINT: To suppress warnings originating from the Eigen headers, use -isystem.
14+ See also:
15+ https://stackoverflow.com/questions/2579576/i-dir-vs-isystem-dir
16+ https://stackoverflow.com/questions/1741816/isystem-for-ms-visual-studio-c-compiler
17+ */
18+
19+ #include " ../numpy.h"
20+
21+ // The C4127 suppression was introduced for Eigen 3.4.0. In theory we could
22+ // make it version specific, or even remove it later, but considering that
23+ // 1. C4127 is generally far more distracting than useful for modern template code, and
24+ // 2. we definitely want to ignore any MSVC warnings originating from Eigen code,
25+ // it is probably best to keep this around indefinitely.
26+ #if defined(_MSC_VER)
27+ # pragma warning(push)
28+ # pragma warning(disable : 4554) // Tensor.h warning
29+ // C5054: operator '&': deprecated between enumerations of different types
30+ #elif defined(__MINGW32__)
31+ # pragma GCC diagnostic push
32+ # pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
33+ #endif
34+
35+ #include < unsupported/Eigen/CXX11/Tensor>
36+
37+ #if defined(_MSC_VER)
38+ # pragma warning(pop)
39+ #elif defined(__MINGW32__)
40+ # pragma GCC diagnostic pop
41+ #endif
42+
43+ PYBIND11_NAMESPACE_BEGIN (PYBIND11_NAMESPACE)
44+
45+ PYBIND11_NAMESPACE_BEGIN(detail)
46+
47+ template <typename T>
48+ constexpr int compute_array_flag_from_tensor() {
49+ static_assert (((int ) T::Layout == (int ) Eigen::RowMajor)
50+ || ((int ) T::Layout == (int ) Eigen::ColMajor),
51+ " Layout must be row or column major" );
52+ return ((int ) T::Layout == (int ) Eigen::RowMajor) ? array::c_style : array::f_style;
53+ }
54+
55+ template <typename T>
56+ struct eigen_tensor_helper {};
57+
58+ template <typename Scalar_, int NumIndices_, int Options_, typename IndexType>
59+ struct eigen_tensor_helper <Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>> {
60+ using T = Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>;
61+ using ValidType = void ;
62+
63+ static std::array<typename T::Index, T::NumIndices> get_shape (const T &f) {
64+ return f.dimensions ();
65+ }
66+
67+ static constexpr bool
68+ is_correct_shape (const std::array<typename T::Index, T::NumIndices> & /* shape*/ ) {
69+ return true ;
70+ }
71+
72+ template <typename T>
73+ struct helper {};
74+
75+ template <size_t ... Is>
76+ struct helper <index_sequence<Is...>> {
77+ static constexpr auto value = concat(const_name(((void ) Is, " ?" ))...);
78+ };
79+
80+ static constexpr auto dimensions_descriptor
81+ = helper<decltype (make_index_sequence<T::NumIndices>())>::value;
82+ };
83+
84+ template <typename Scalar_, typename std::ptrdiff_t ... Indices, int Options_, typename IndexType>
85+ struct eigen_tensor_helper <
86+ Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>> {
87+ using T = Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>;
88+ using ValidType = void ;
89+
90+ static constexpr std::array<typename T::Index, T::NumIndices> get_shape (const T & /* f*/ ) {
91+ return get_shape ();
92+ }
93+
94+ static constexpr std::array<typename T::Index, T::NumIndices> get_shape () {
95+ return {{Indices...}};
96+ }
97+
98+ static bool is_correct_shape (const std::array<typename T::Index, T::NumIndices> &shape) {
99+ return get_shape () == shape;
100+ }
101+
102+ static constexpr auto dimensions_descriptor = concat(const_name<Indices>()...);
103+ };
104+
105+ template <typename T>
106+ struct get_tensor_descriptor {
107+ static constexpr auto value
108+ = const_name(" numpy.ndarray[" ) + npy_format_descriptor<typename T::Scalar>::name
109+ + const_name(" [" ) + eigen_tensor_helper<T>::dimensions_descriptor
110+ + const_name(" ], flags.writeable, " )
111+ + const_name<(int ) T::Layout == (int ) Eigen::RowMajor>(" flags.c_contiguous" ,
112+ " flags.f_contiguous" );
113+ };
114+
115+ template <typename Type>
116+ struct type_caster <Type, typename eigen_tensor_helper<Type>::ValidType> {
117+ using H = eigen_tensor_helper<Type>;
118+ PYBIND11_TYPE_CASTER (Type, get_tensor_descriptor<Type>::value);
119+
120+ bool load (handle src, bool /* convert*/ ) {
121+ array_t <typename Type::Scalar, compute_array_flag_from_tensor<Type>()> a (
122+ reinterpret_borrow<object>(src));
123+
124+ if (a.ndim () != Type::NumIndices) {
125+ return false ;
126+ }
127+
128+ std::array<typename Type::Index, Type::NumIndices> shape;
129+ std::copy (a.shape (), a.shape () + Type::NumIndices, shape.begin ());
130+
131+ if (!H::is_correct_shape (shape)) {
132+ return false ;
133+ }
134+
135+ value = Eigen::TensorMap<Type>(const_cast <typename Type::Scalar *>(a.data ()), shape);
136+
137+ return true ;
138+ }
139+
140+ static handle cast (Type &&src, return_value_policy policy, handle parent) {
141+ if (policy == return_value_policy::reference
142+ || policy == return_value_policy::reference_internal) {
143+ pybind11_fail (" Cannot use a reference return value policy for an rvalue" );
144+ }
145+ return cast_impl (&src, return_value_policy::move, parent);
146+ }
147+
148+ static handle cast (const Type &&src, return_value_policy policy, handle parent) {
149+ if (policy == return_value_policy::reference
150+ || policy == return_value_policy::reference_internal) {
151+ pybind11_fail (" Cannot use a reference return value policy for an rvalue" );
152+ }
153+ return cast_impl (&src, return_value_policy::move, parent);
154+ }
155+
156+ static handle cast (Type &src, return_value_policy policy, handle parent) {
157+ if (policy == return_value_policy::automatic
158+ || policy == return_value_policy::automatic_reference) {
159+ policy = return_value_policy::copy;
160+ }
161+ return cast_impl (&src, policy, parent);
162+ }
163+
164+ static handle cast (const Type &src, return_value_policy policy, handle parent) {
165+ if (policy == return_value_policy::automatic
166+ || policy == return_value_policy::automatic_reference) {
167+ policy = return_value_policy::copy;
168+ }
169+ return cast (&src, policy, parent);
170+ }
171+
172+ static handle cast (Type *src, return_value_policy policy, handle parent) {
173+ if (policy == return_value_policy::automatic) {
174+ policy = return_value_policy::take_ownership;
175+ } else if (policy == return_value_policy::automatic_reference) {
176+ policy = return_value_policy::reference;
177+ }
178+ return cast_impl (src, policy, parent);
179+ }
180+
181+ static handle cast (const Type *src, return_value_policy policy, handle parent) {
182+ if (policy == return_value_policy::automatic) {
183+ policy = return_value_policy::take_ownership;
184+ } else if (policy == return_value_policy::automatic_reference) {
185+ policy = return_value_policy::reference;
186+ }
187+ return cast_impl (src, policy, parent);
188+ }
189+
190+ template <typename C>
191+ static handle cast_impl (C *src, return_value_policy policy, handle parent) {
192+ object parent_object;
193+ bool writeable = false ;
194+ switch (policy) {
195+ case return_value_policy::move:
196+ if (std::is_const<C>::value) {
197+ pybind11_fail (" Cannot move from a constant reference" );
198+ }
199+ {
200+ Eigen::aligned_allocator<Type> allocator;
201+ Type *copy = ::new (allocator.allocate (1 )) Type (std::move (*src));
202+ src = copy;
203+ }
204+
205+ parent_object = capsule (src, [](void *ptr) {
206+ Eigen::aligned_allocator<Type> allocator;
207+ Type *copy = (Type *) ptr;
208+ copy->~Type ();
209+ allocator.deallocate (copy, 1 );
210+ });
211+ writeable = true ;
212+ break ;
213+
214+ case return_value_policy::take_ownership:
215+ if (std::is_const<C>::value) {
216+ pybind11_fail (" Cannot take ownership of a const reference" );
217+ }
218+ parent_object = capsule (src, [](void *ptr) { delete (Type *) ptr; });
219+ writeable = true ;
220+ break ;
221+
222+ case return_value_policy::copy:
223+ parent_object = {};
224+ writeable = true ;
225+ break ;
226+
227+ case return_value_policy::reference:
228+ parent_object = none ();
229+ writeable = !std::is_const<C>::value;
230+ break ;
231+
232+ case return_value_policy::reference_internal:
233+ // Default should do the right thing
234+ parent_object = reinterpret_borrow<object>(parent);
235+ writeable = !std::is_const<C>::value;
236+ break ;
237+
238+ default :
239+ pybind11_fail (" pybind11 bug in eigen.h, please file a bug report" );
240+ }
241+
242+ handle result = array_t <typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
243+ H::get_shape (*src), src->data (), parent_object)
244+ .release ();
245+
246+ if (!writeable) {
247+ array_proxy (result.ptr ())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
248+ }
249+
250+ return result;
251+ }
252+ };
253+
254+ template <typename Type>
255+ struct type_caster <Eigen::TensorMap<Type>, typename eigen_tensor_helper<Type>::ValidType> {
256+ using H = eigen_tensor_helper<Type>;
257+
258+ bool load (handle src, bool /* convert*/ ) {
259+ // Note that we have a lot more checks here as we want to make sure to avoid copies
260+ auto a = reinterpret_borrow<array>(src);
261+ if ((a.flags () & compute_array_flag_from_tensor<Type>()) == 0 ) {
262+ return false ;
263+ }
264+
265+ if (!a.dtype ().is (dtype::of<typename Type::Scalar>())) {
266+ return false ;
267+ }
268+
269+ if (a.ndim () != Type::NumIndices) {
270+ return false ;
271+ }
272+
273+ std::array<typename Type::Index, Type::NumIndices> shape;
274+ std::copy (a.shape (), a.shape () + Type::NumIndices, shape.begin ());
275+
276+ if (!H::is_correct_shape (shape)) {
277+ return false ;
278+ }
279+
280+ value.reset (new Eigen::TensorMap<Type>(
281+ reinterpret_cast <typename Type::Scalar *>(a.mutable_data ()), shape));
282+
283+ return true ;
284+ }
285+
286+ static handle cast (Eigen::TensorMap<Type> &&src, return_value_policy policy, handle parent) {
287+ return cast_impl (&src, policy, parent);
288+ }
289+
290+ static handle
291+ cast (const Eigen::TensorMap<Type> &&src, return_value_policy policy, handle parent) {
292+ return cast_impl (&src, policy, parent);
293+ }
294+
295+ static handle cast (Eigen::TensorMap<Type> &src, return_value_policy policy, handle parent) {
296+ if (policy == return_value_policy::automatic
297+ || policy == return_value_policy::automatic_reference) {
298+ policy = return_value_policy::copy;
299+ }
300+ return cast_impl (&src, policy, parent);
301+ }
302+
303+ static handle
304+ cast (const Eigen::TensorMap<Type> &src, return_value_policy policy, handle parent) {
305+ if (policy == return_value_policy::automatic
306+ || policy == return_value_policy::automatic_reference) {
307+ policy = return_value_policy::copy;
308+ }
309+ return cast (&src, policy, parent);
310+ }
311+
312+ static handle cast (Eigen::TensorMap<Type> *src, return_value_policy policy, handle parent) {
313+ if (policy == return_value_policy::automatic) {
314+ policy = return_value_policy::take_ownership;
315+ } else if (policy == return_value_policy::automatic_reference) {
316+ policy = return_value_policy::reference;
317+ }
318+ return cast_impl (src, policy, parent);
319+ }
320+
321+ static handle
322+ cast (const Eigen::TensorMap<Type> *src, return_value_policy policy, handle parent) {
323+ if (policy == return_value_policy::automatic) {
324+ policy = return_value_policy::take_ownership;
325+ } else if (policy == return_value_policy::automatic_reference) {
326+ policy = return_value_policy::reference;
327+ }
328+ return cast_impl (src, policy, parent);
329+ }
330+
331+ template <typename C>
332+ static handle cast_impl (C *src, return_value_policy policy, handle parent) {
333+ object parent_object;
334+ constexpr bool writeable = !std::is_const<C>::value;
335+ switch (policy) {
336+ case return_value_policy::reference:
337+ parent_object = none ();
338+ break ;
339+
340+ case return_value_policy::reference_internal:
341+ // Default should do the right thing
342+ parent_object = reinterpret_borrow<object>(parent);
343+ break ;
344+
345+ default :
346+ // move, take_ownership don't make any sense for a ref/map:
347+ pybind11_fail (" Invalid return_value_policy for Eigen Map type, must be either "
348+ " reference or reference_internal" );
349+ }
350+
351+ handle result = array_t <typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
352+ H::get_shape (*src), src->data (), parent_object)
353+ .release ();
354+
355+ if (!writeable) {
356+ array_proxy (result.ptr ())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
357+ }
358+
359+ return result;
360+ }
361+
362+ protected:
363+ // TODO: Move to std::optional once std::optional has more support
364+ std::unique_ptr<Eigen::TensorMap<Type>> value;
365+
366+ public:
367+ static constexpr auto name = get_tensor_descriptor<Type>::value;
368+ explicit operator Eigen::TensorMap<Type> *() {
369+ return value.get ();
370+ } /* NOLINT(bugprone-macro-parentheses) */
371+ explicit operator Eigen::TensorMap<Type> &() {
372+ return *value;
373+ } /* NOLINT(bugprone-macro-parentheses) */
374+ explicit operator Eigen::TensorMap<Type> &&() && {
375+ return std::move (*value);
376+ } /* NOLINT(bugprone-macro-parentheses) */
377+
378+ template <typename T_>
379+ using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>;
380+ };
381+
382+ PYBIND11_NAMESPACE_END (detail)
383+ PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
0 commit comments