diff --git a/src/client.h b/src/client.h index 8d95748..04f26da 100644 --- a/src/client.h +++ b/src/client.h @@ -41,7 +41,8 @@ class RPCClient { } // blocking call - while (!get_response(msg_id_wait, result)){ + RpcError tmp_error; + while (!get_response(msg_id_wait, result, tmp_error)) { //delay(1); } @@ -60,18 +61,18 @@ class RPCClient { } template - bool get_response(const uint32_t wait_id, RType& result) { - RpcError tmp_error; + bool get_response(const uint32_t wait_id, RType& result, RpcError& error) { decoder->decode(); - if (decoder->get_response(wait_id, result, tmp_error)) { - lastError.code = tmp_error.code; - lastError.traceback = tmp_error.traceback; + if (decoder->get_response(wait_id, result, error)) { + lastError.copy(error); return true; } return false; } + uint32_t get_discarded_packets() const {return decoder->get_discarded_packets();} + }; #endif //RPCLITE_CLIENT_H diff --git a/src/decoder.h b/src/decoder.h index f3ab3d3..883abeb 100644 --- a/src/decoder.h +++ b/src/decoder.h @@ -62,29 +62,54 @@ class RpcDecoder { MsgPack::Unpacker unpacker; unpacker.clear(); - size_t res_size = get_packet_size(); - if (!unpacker.feed(_raw_buffer, res_size)) return false; + if (!unpacker.feed(_raw_buffer, _packet_size)) return false; MsgPack::arr_size_t resp_size; int resp_type; uint32_t resp_id; if (!unpacker.deserialize(resp_size, resp_type, resp_id)) return false; - if (resp_size.size() != RESPONSE_SIZE) return false; - if (resp_type != RESP_MSG) return false; + + // ReSharper disable once CppDFAUnreachableCode if (resp_id != msg_id) return false; + // msg_id OK packet will be consumed. + if (resp_type != RESP_MSG) { + // This should never happen + error.code = PARSING_ERR; + error.traceback = "Unexpected response type"; + discard(); + return true; + } + + if (resp_size.size() != RESPONSE_SIZE) { + // This should never happen + error.code = PARSING_ERR; + error.traceback = "Unexpected RPC response size"; + discard(); + return true; + } + MsgPack::object::nil_t nil; if (unpacker.unpackable(nil)){ // No error - if (!unpacker.deserialize(nil, result)) return false; + if (!unpacker.deserialize(nil, result)) { + error.code = PARSING_ERR; + error.traceback = "Result not parsable (check type)"; + discard(); + return true; + } } else { // RPC returned an error - if (!unpacker.deserialize(error, nil)) return false; + if (!unpacker.deserialize(error, nil)) { + error.code = PARSING_ERR; + error.traceback = "RPC Error not parsable (check type)"; + discard(); + return true; + } } + consume(_packet_size); reset_packet(); - consume(res_size); return true; - } bool send_response(const MsgPack::Packer& packer) const { @@ -103,8 +128,7 @@ class RpcDecoder { unpacker.clear(); if (!unpacker.feed(_raw_buffer, _packet_size)) { // feed should not fail at this point - consume(_packet_size); - reset_packet(); + discard(); return ""; }; @@ -113,27 +137,24 @@ class RpcDecoder { MsgPack::arr_size_t req_size; if (!unpacker.deserialize(req_size, msg_type)) { - consume(_packet_size); - reset_packet(); + discard(); return ""; // Header not unpackable } + // ReSharper disable once CppDFAUnreachableCode if (msg_type == CALL_MSG && req_size.size() == REQUEST_SIZE) { uint32_t msg_id; if (!unpacker.deserialize(msg_id, method)) { - consume(_packet_size); - reset_packet(); + discard(); return ""; // Method not unpackable } } else if (msg_type == NOTIFY_MSG && req_size.size() == NOTIFY_SIZE) { if (!unpacker.deserialize(method)) { - consume(_packet_size); - reset_packet(); + discard(); return ""; // Method not unpackable } } else { - consume(_packet_size); - reset_packet(); + discard(); return ""; // Invalid request size/type } @@ -183,11 +204,13 @@ class RpcDecoder { if (type != CALL_MSG && type != RESP_MSG && type != NOTIFY_MSG) { consume(bytes_checked); + _discarded_packets++; break; // Not a valid RPC type (could be type=WRONG_MSG) } if ((type == CALL_MSG && container_size != REQUEST_SIZE) || (type == RESP_MSG && container_size != RESPONSE_SIZE) || (type == NOTIFY_MSG && container_size != NOTIFY_SIZE)) { consume(bytes_checked); + _discarded_packets++; break; // Not a valid RPC format } @@ -210,6 +233,8 @@ class RpcDecoder { size_t size() const {return _bytes_stored;} + uint32_t get_discarded_packets() const {return _discarded_packets;} + friend class DecoderTester; private: @@ -219,6 +244,7 @@ class RpcDecoder { int _packet_type = NO_MSG; size_t _packet_size = 0; uint32_t _msg_id = 0; + uint32_t _discarded_packets = 0; bool buffer_full() const { return _bytes_stored == BufferSize; } @@ -252,6 +278,11 @@ class RpcDecoder { return consume(packet_size); } + void discard() { + consume(_packet_size); + reset_packet(); + _discarded_packets++; + } void reset_packet() { _packet_type = NO_MSG; diff --git a/src/error.h b/src/error.h index 925b30a..c98cb7d 100644 --- a/src/error.h +++ b/src/error.h @@ -17,6 +17,7 @@ #include "MsgPack.h" #define NO_ERR 0x00 +#define PARSING_ERR 0xFC #define MALFORMED_CALL_ERR 0xFD #define FUNCTION_NOT_FOUND_ERR 0xFE #define GENERIC_ERR 0xFF @@ -34,6 +35,11 @@ struct RpcError { RpcError(const int c, MsgPack::str_t tb) : code(c), traceback(std::move(tb)) {} + void copy(const RpcError& err) { + code = err.code; + traceback = err.traceback; + } + MSGPACK_DEFINE(code, traceback); // -> [code, traceback] }; diff --git a/src/server.h b/src/server.h index 3200b40..f3a5803 100644 --- a/src/server.h +++ b/src/server.h @@ -87,6 +87,8 @@ class RPCServer { } + uint32_t get_discarded_packets() const {return decoder->get_discarded_packets();} + private: RpcDecoder<>* decoder = nullptr; RpcFunctionDispatcher dispatcher{};