diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h index 3ea6406b69a37..14a3d8e3d6ad6 100644 --- a/orc-rt/include/orc-rt/SPSWrapperFunction.h +++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h @@ -21,8 +21,10 @@ namespace orc_rt { namespace detail { template struct WFSPSHelper { - template - std::optional serialize(const ArgTs &...Args) { +private: + template + std::optional + serializeImpl(const SerializableArgTs &...Args) { auto R = WrapperFunctionBuffer::allocate(SPSArgList::size(Args...)); SPSOutputBuffer OB(R.data(), R.size()); @@ -31,16 +33,61 @@ template struct WFSPSHelper { return std::move(R); } + template static const T &toSerializable(const T &Arg) noexcept { + return Arg; + } + + static SPSSerializableError toSerializable(Error Err) noexcept { + return SPSSerializableError(std::move(Err)); + } + + template + static SPSSerializableExpected toSerializable(Expected Arg) noexcept { + return SPSSerializableExpected(std::move(Arg)); + } + + template struct DeserializableTuple; + + template struct DeserializableTuple> { + typedef std::tuple< + std::decay_t()))>...> + type; + }; + + template + using DeserializableTuple_t = typename DeserializableTuple::type; + + template static T fromSerializable(T &&Arg) noexcept { + return Arg; + } + + static Error fromSerializable(SPSSerializableError Err) noexcept { + return Err.toError(); + } + + template + static Expected fromSerializable(SPSSerializableExpected Val) noexcept { + return Val.toExpected(); + } + +public: + template + std::optional serialize(ArgTs &&...Args) { + return serializeImpl(toSerializable(std::forward(Args))...); + } + template std::optional deserialize(WrapperFunctionBuffer ArgBytes) { assert(!ArgBytes.getOutOfBandError() && "Should not attempt to deserialize out-of-band error"); SPSInputBuffer IB(ArgBytes.data(), ArgBytes.size()); - ArgTuple Args; - if (!SPSSerializationTraits, ArgTuple>::deserialize( - IB, Args)) + DeserializableTuple_t Args; + if (!SPSSerializationTraits, + decltype(Args)>::deserialize(IB, Args)) return std::nullopt; - return Args; + return std::apply( + [](auto &&...A) { return ArgTuple(fromSerializable(A)...); }, + std::move(Args)); } }; diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h index 233c3b21e041d..ca165db7188b4 100644 --- a/orc-rt/include/orc-rt/WrapperFunction.h +++ b/orc-rt/include/orc-rt/WrapperFunction.h @@ -168,7 +168,8 @@ struct ResultDeserializer>, Serializer> { Serializer &S) { if (auto Val = S.result().template deserialize>( std::move(ResultBytes))) - return std::move(std::get<0>(*Val)); + return Expected(std::move(std::get<0>(*Val)), + ForceExpectedSuccessValue()); else return make_error("Could not deserialize result"); } diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp index 0b65515120b7f..c0c86ff8715ce 100644 --- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp +++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp @@ -144,3 +144,77 @@ TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) { [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EXPECT_EQ(Result, 42); } + +static void improbable_feat_sps_wrapper(orc_rt_SessionRef Session, + void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction::handle( + Session, CallCtx, Return, ArgBytes, + [](move_only_function Return, bool LuckyHat) { + if (LuckyHat) + Return(Error::success()); + else + Return(make_error("crushed by boulder")); + }); +} + +TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorSuccessCase) { + bool DidRun = false; + SPSWrapperFunction::call( + DirectCaller(nullptr, improbable_feat_sps_wrapper), + [&](Expected E) { + DidRun = true; + cantFail(cantFail(std::move(E))); + }, + true); + + EXPECT_TRUE(DidRun); +} + +TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorFailureCase) { + std::string ErrMsg; + SPSWrapperFunction::call( + DirectCaller(nullptr, improbable_feat_sps_wrapper), + [&](Expected E) { ErrMsg = toString(cantFail(std::move(E))); }, + false); + + EXPECT_EQ(ErrMsg, "crushed by boulder"); +} + +static void halve_number_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction(int32_t)>::handle( + Session, CallCtx, Return, ArgBytes, + [](move_only_function)> Return, int N) { + if (N % 2 == 0) + Return(N >> 1); + else + Return(make_error("N is not a multiple of 2")); + }); +} + +TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedSuccessCase) { + int32_t Result = 0; + SPSWrapperFunction(int32_t)>::call( + DirectCaller(nullptr, halve_number_sps_wrapper), + [&](Expected> R) { + Result = cantFail(cantFail(std::move(R))); + }, + 2); + + EXPECT_EQ(Result, 1); +} + +TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) { + std::string ErrMsg; + SPSWrapperFunction(int32_t)>::call( + DirectCaller(nullptr, halve_number_sps_wrapper), + [&](Expected> R) { + ErrMsg = toString(cantFail(std::move(R)).takeError()); + }, + 3); + + EXPECT_EQ(ErrMsg, "N is not a multiple of 2"); +}