Skip to content

Commit 64805ec

Browse files
committed
Implement py3_enum<> mapping to enum.IntEnum
1 parent 7348c40 commit 64805ec

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

include/pybind11/cast.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,77 @@ template <typename... Tuple> class type_caster<std::tuple<Tuple...>> {
890890
std::tuple<make_caster<Tuple>...> value;
891891
};
892892

893+
struct py3_enum_info {
894+
handle type = {};
895+
std::unordered_map<long long, handle> values = {};
896+
897+
py3_enum_info() = default;
898+
899+
py3_enum_info(handle type, const dict& values) : type(type) {
900+
for (auto item : values)
901+
this->values[static_cast<long long>(item.second.cast<int>())] = type.attr(item.first);
902+
}
903+
904+
static std::unordered_map<std::type_index, py3_enum_info>& registry() {
905+
static std::unordered_map<std::type_index, py3_enum_info> map = {};
906+
return map;
907+
}
908+
909+
template<typename T>
910+
static void bind(handle type, const dict& values) {
911+
registry()[typeid(T)] = py3_enum_info(type, values);
912+
}
913+
914+
template<typename T>
915+
static const py3_enum_info* get() {
916+
auto it = registry().find(typeid(T));
917+
return it == registry().end() ? nullptr : &it->second;
918+
}
919+
};
920+
921+
template<typename T>
922+
struct type_caster<T, enable_if_t<std::is_enum<T>::value>> {
923+
private:
924+
using base_caster = type_caster_base<T>;
925+
base_caster caster;
926+
bool py3 = false;
927+
T value;
928+
929+
public:
930+
template<typename U> using cast_op_type = pybind11::detail::cast_op_type<U>;
931+
932+
operator T*() { return py3 ? &value : static_cast<T*>(caster); }
933+
operator T&() { return py3 ? value : static_cast<T&>(caster); }
934+
935+
static handle cast(const T& src, return_value_policy rvp, handle parent) {
936+
if (auto info = py3_enum_info::get<T>()) {
937+
auto it = info->values.find(static_cast<long long>(src));
938+
if (it == info->values.end())
939+
return {};
940+
return it->second.inc_ref();
941+
}
942+
return base_caster::cast(src, rvp, parent);
943+
}
944+
945+
bool load(handle src, bool convert) {
946+
if (!src)
947+
return false;
948+
if (auto info = py3_enum_info::get<T>()) {
949+
py3 = true;
950+
if (!isinstance(src, info->type))
951+
return false;
952+
value = static_cast<T>(src.cast<long long>());
953+
return true;
954+
}
955+
py3 = false;
956+
return caster.load(src, convert);
957+
}
958+
959+
static PYBIND11_DESCR name() {
960+
return base_caster::name();
961+
}
962+
};
963+
893964
/// Helper class which abstracts away certain actions. Users can provide specializations for
894965
/// custom holders, but it's only necessary if the type has a non-standard interface.
895966
template <typename T>

include/pybind11/pybind11.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,39 @@ template <typename Type> class enum_ : public class_<Type> {
12381238
handle m_parent;
12391239
};
12401240

1241+
template<typename T>
1242+
class py3_enum {
1243+
public:
1244+
using underlying_type = typename std::underlying_type<T>::type;
1245+
1246+
py3_enum(handle scope, const char* name)
1247+
: name(name),
1248+
parent(scope),
1249+
ctor(module::import("enum").attr("IntEnum")),
1250+
unique(module::import("enum").attr("unique")) {
1251+
update();
1252+
}
1253+
1254+
py3_enum& value(const char* name, T value) {
1255+
entries[name] = cast(static_cast<underlying_type>(value));
1256+
update();
1257+
return *this;
1258+
}
1259+
1260+
private:
1261+
const char *name;
1262+
handle parent;
1263+
dict entries;
1264+
object ctor;
1265+
object unique;
1266+
1267+
void update() {
1268+
object type = unique(ctor(name, entries));
1269+
setattr(parent, name, type);
1270+
detail::py3_enum_info::bind<T>(type, entries);
1271+
}
1272+
};
1273+
12411274
NAMESPACE_BEGIN(detail)
12421275
template <typename... Args> struct init {
12431276
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>

0 commit comments

Comments
 (0)