Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions orc-rt/include/orc-rt/SPSWrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ namespace orc_rt {
namespace detail {

template <typename... SPSArgTs> struct WFSPSHelper {
template <typename... ArgTs>
std::optional<WrapperFunctionBuffer> serialize(const ArgTs &...Args) {
private:
template <typename... SerializableArgTs>
std::optional<WrapperFunctionBuffer>
serializeImpl(const SerializableArgTs &...Args) {
auto R =
WrapperFunctionBuffer::allocate(SPSArgList<SPSArgTs...>::size(Args...));
SPSOutputBuffer OB(R.data(), R.size());
Expand All @@ -31,16 +33,61 @@ template <typename... SPSArgTs> struct WFSPSHelper {
return std::move(R);
}

template <typename T> static const T &toSerializable(const T &Arg) noexcept {
return Arg;
}

static SPSSerializableError toSerializable(Error Err) noexcept {
return SPSSerializableError(std::move(Err));
}

template <typename T>
static SPSSerializableExpected<T> toSerializable(Expected<T> Arg) noexcept {
return SPSSerializableExpected<T>(std::move(Arg));
}

template <typename... Ts> struct DeserializableTuple;

template <typename... Ts> struct DeserializableTuple<std::tuple<Ts...>> {
typedef std::tuple<
std::decay_t<decltype(toSerializable(std::declval<Ts>()))>...>
type;
};

template <typename... Ts>
using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;

template <typename T> static T fromSerializable(T &&Arg) noexcept {
return Arg;
}

static Error fromSerializable(SPSSerializableError Err) noexcept {
return Err.toError();
}

template <typename T>
static Expected<T> fromSerializable(SPSSerializableExpected<T> Val) noexcept {
return Val.toExpected();
}

public:
template <typename... ArgTs>
std::optional<WrapperFunctionBuffer> serialize(ArgTs &&...Args) {
return serializeImpl(toSerializable(std::forward<ArgTs>(Args))...);
}

template <typename ArgTuple>
std::optional<ArgTuple> deserialize(WrapperFunctionBuffer ArgBytes) {
assert(!ArgBytes.getOutOfBandError() &&
"Should not attempt to deserialize out-of-band error");
SPSInputBuffer IB(ArgBytes.data(), ArgBytes.size());
ArgTuple Args;
if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>, ArgTuple>::deserialize(
IB, Args))
DeserializableTuple_t<ArgTuple> Args;
if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>,
decltype(Args)>::deserialize(IB, Args))
return std::nullopt;
return Args;
return std::apply(
[](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
std::move(Args));
}
};

Expand Down
3 changes: 2 additions & 1 deletion orc-rt/include/orc-rt/WrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ struct ResultDeserializer<std::tuple<Expected<T>>, Serializer> {
Serializer &S) {
if (auto Val = S.result().template deserialize<std::tuple<T>>(
std::move(ResultBytes)))
return std::move(std::get<0>(*Val));
return Expected<T>(std::move(std::get<0>(*Val)),
ForceExpectedSuccessValue());
else
return make_error<StringError>("Could not deserialize result");
}
Expand Down
74 changes: 74 additions & 0 deletions orc-rt/unittests/SPSWrapperFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,77 @@ TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) {
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
EXPECT_EQ(Result, 42);
}

static void improbable_feat_sps_wrapper(orc_rt_SessionRef Session,
void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<SPSError(bool)>::handle(
Session, CallCtx, Return, ArgBytes,
[](move_only_function<void(Error)> Return, bool LuckyHat) {
if (LuckyHat)
Return(Error::success());
else
Return(make_error<StringError>("crushed by boulder"));
});
}

TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorSuccessCase) {
bool DidRun = false;
SPSWrapperFunction<SPSError(bool)>::call(
DirectCaller(nullptr, improbable_feat_sps_wrapper),
[&](Expected<Error> E) {
DidRun = true;
cantFail(cantFail(std::move(E)));
},
true);

EXPECT_TRUE(DidRun);
}

TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorFailureCase) {
std::string ErrMsg;
SPSWrapperFunction<SPSError(bool)>::call(
DirectCaller(nullptr, improbable_feat_sps_wrapper),
[&](Expected<Error> E) { ErrMsg = toString(cantFail(std::move(E))); },
false);

EXPECT_EQ(ErrMsg, "crushed by boulder");
}

static void halve_number_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::handle(
Session, CallCtx, Return, ArgBytes,
[](move_only_function<void(Expected<int32_t>)> Return, int N) {
if (N % 2 == 0)
Return(N >> 1);
else
Return(make_error<StringError>("N is not a multiple of 2"));
});
}

TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedSuccessCase) {
int32_t Result = 0;
SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
DirectCaller(nullptr, halve_number_sps_wrapper),
[&](Expected<Expected<int32_t>> R) {
Result = cantFail(cantFail(std::move(R)));
},
2);

EXPECT_EQ(Result, 1);
}

TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
std::string ErrMsg;
SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
DirectCaller(nullptr, halve_number_sps_wrapper),
[&](Expected<Expected<int32_t>> R) {
ErrMsg = toString(cantFail(std::move(R)).takeError());
},
3);

EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
}
Loading