diff --git a/include/aws/event-stream/event_stream.h b/include/aws/event-stream/event_stream.h index 41302db..a9b9b2b 100644 --- a/include/aws/event-stream/event_stream.h +++ b/include/aws/event-stream/event_stream.h @@ -20,6 +20,9 @@ /* max header size is 128kb */ #define AWS_EVENT_STREAM_MAX_HEADERS_SIZE (128 * 1024) +/* Max header name length is 127 bytes */ +#define AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX (INT8_MAX) + enum aws_event_stream_errors { AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH = AWS_ERROR_ENUM_BEGIN_RANGE(AWS_C_EVENT_STREAM_PACKAGE_ID), AWS_ERROR_EVENT_STREAM_INSUFFICIENT_BUFFER_LEN, @@ -54,8 +57,7 @@ struct aws_event_stream_message_prelude { struct aws_event_stream_message { struct aws_allocator *alloc; - uint8_t *message_buffer; - uint8_t owns_buffer; + struct aws_byte_buf message_buffer; }; #define AWS_EVENT_STREAM_PRELUDE_LENGTH (uint32_t)(sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint32_t)) @@ -76,6 +78,7 @@ enum aws_event_stream_header_value_type { AWS_EVENT_STREAM_HEADER_UUID }; +static const uint16_t UUID_LEN = 16U; struct aws_event_stream_header_value_pair { uint8_t header_name_len; char header_name[INT8_MAX]; @@ -244,6 +247,22 @@ AWS_EVENT_STREAM_API const uint8_t *aws_event_stream_message_buffer(const struct AWS_EVENT_STREAM_API uint32_t aws_event_stream_compute_headers_required_buffer_len(const struct aws_array_list *headers); +/** + * Writes headers to buf assuming buf is large enough to hold the data. Prefer this function over the unsafe variant + * 'aws_event_stream_write_headers_to_buffer'. + * + * Returns AWS_OP_SUCCESS if the headers were successfully and completely written and AWS_OP_ERR otherwise. + */ +AWS_EVENT_STREAM_API int aws_event_stream_write_headers_to_buffer_safe( + const struct aws_array_list *headers, + struct aws_byte_buf *buf); + +/** + * Deprecated in favor of 'aws_event_stream_write_headers_to_buffer_safe' as this API is unsafe. + * + * Writes headers to buffer and returns the length of bytes written to buffer. Assumes buffer is large enough to + * store the headers. + */ AWS_EVENT_STREAM_API size_t aws_event_stream_write_headers_to_buffer(const struct aws_array_list *headers, uint8_t *buffer); diff --git a/source/event_stream.c b/source/event_stream.c index 224690f..bac892e 100644 --- a/source/event_stream.c +++ b/source/event_stream.c @@ -7,6 +7,7 @@ #include +#include #include #include @@ -128,81 +129,100 @@ uint32_t aws_event_stream_compute_headers_required_buffer_len(const struct aws_a struct aws_event_stream_header_value_pair *header = NULL; aws_array_list_get_at_ptr(headers, (void **)&header, i); - - headers_len += sizeof(header->header_name_len) + header->header_name_len + 1; + AWS_FATAL_ASSERT( + !aws_add_size_checked(headers_len, sizeof(header->header_name_len), &headers_len) && + "integer overflow occurred computing total headers length."); + AWS_FATAL_ASSERT( + !aws_add_size_checked(headers_len, header->header_name_len + 1, &headers_len) && + "integer overflow occurred computing total headers length."); if (header->header_value_type == AWS_EVENT_STREAM_HEADER_STRING || header->header_value_type == AWS_EVENT_STREAM_HEADER_BYTE_BUF) { - headers_len += sizeof(header->header_value_len); + AWS_FATAL_ASSERT( + !aws_add_size_checked(headers_len, sizeof(header->header_value_len), &headers_len) && + "integer overflow occurred computing total headers length."); } if (header->header_value_type != AWS_EVENT_STREAM_HEADER_BOOL_FALSE && header->header_value_type != AWS_EVENT_STREAM_HEADER_BOOL_TRUE) { - headers_len += header->header_value_len; + AWS_FATAL_ASSERT( + !aws_add_size_checked(headers_len, header->header_value_len, &headers_len) && + "integer overflow occurred computing total headers length."); } } return (uint32_t)headers_len; } -/* adds the headers represented in the headers list to the buffer. - returns the new buffer offset for use elsewhere. Assumes buffer length is at least the length of the return value - from compute_headers_length() */ -size_t aws_event_stream_write_headers_to_buffer(const struct aws_array_list *headers, uint8_t *buffer) { +int aws_event_stream_write_headers_to_buffer_safe(const struct aws_array_list *headers, struct aws_byte_buf *buf) { + AWS_FATAL_PRECONDITION(buf); + if (!headers || !aws_array_list_length(headers)) { - return 0; + return AWS_OP_SUCCESS; } size_t headers_count = aws_array_list_length(headers); - uint8_t *buffer_alias = buffer; for (size_t i = 0; i < headers_count; ++i) { struct aws_event_stream_header_value_pair *header = NULL; aws_array_list_get_at_ptr(headers, (void **)&header, i); - *buffer_alias = (uint8_t)header->header_name_len; - buffer_alias++; - memcpy(buffer_alias, header->header_name, (size_t)header->header_name_len); - buffer_alias += header->header_name_len; - *buffer_alias = (uint8_t)header->header_value_type; - buffer_alias++; + AWS_RETURN_ERROR_IF( + aws_byte_buf_write_u8(buf, header->header_name_len), AWS_ERROR_EVENT_STREAM_INSUFFICIENT_BUFFER_LEN); + AWS_RETURN_ERROR_IF( + aws_byte_buf_write(buf, (uint8_t *)header->header_name, (size_t)header->header_name_len), + AWS_ERROR_EVENT_STREAM_INSUFFICIENT_BUFFER_LEN); + AWS_RETURN_ERROR_IF( + aws_byte_buf_write_u8(buf, (uint8_t)header->header_value_type), + AWS_ERROR_EVENT_STREAM_INSUFFICIENT_BUFFER_LEN); + switch (header->header_value_type) { case AWS_EVENT_STREAM_HEADER_BOOL_FALSE: case AWS_EVENT_STREAM_HEADER_BOOL_TRUE: break; + /* additions of integers here assume the endianness conversion has already happened */ case AWS_EVENT_STREAM_HEADER_BYTE: - *buffer_alias = header->header_value.static_val[0]; - buffer_alias++; - break; - /* additions of integers here assume the endianness conversion has already happened */ case AWS_EVENT_STREAM_HEADER_INT16: - memcpy(buffer_alias, header->header_value.static_val, sizeof(uint16_t)); - buffer_alias += sizeof(uint16_t); - break; case AWS_EVENT_STREAM_HEADER_INT32: - memcpy(buffer_alias, header->header_value.static_val, sizeof(uint32_t)); - buffer_alias += sizeof(uint32_t); - break; case AWS_EVENT_STREAM_HEADER_INT64: case AWS_EVENT_STREAM_HEADER_TIMESTAMP: - memcpy(buffer_alias, header->header_value.static_val, sizeof(uint64_t)); - buffer_alias += sizeof(uint64_t); + case AWS_EVENT_STREAM_HEADER_UUID: + AWS_RETURN_ERROR_IF( + aws_byte_buf_write(buf, header->header_value.static_val, header->header_value_len), + AWS_ERROR_EVENT_STREAM_INSUFFICIENT_BUFFER_LEN); break; case AWS_EVENT_STREAM_HEADER_BYTE_BUF: case AWS_EVENT_STREAM_HEADER_STRING: - aws_write_u16(header->header_value_len, buffer_alias); - buffer_alias += sizeof(uint16_t); - memcpy(buffer_alias, header->header_value.variable_len_val, header->header_value_len); - buffer_alias += header->header_value_len; + AWS_RETURN_ERROR_IF( + aws_byte_buf_write_be16(buf, header->header_value_len), + AWS_ERROR_EVENT_STREAM_INSUFFICIENT_BUFFER_LEN); + AWS_RETURN_ERROR_IF( + aws_byte_buf_write(buf, header->header_value.variable_len_val, header->header_value_len), + AWS_ERROR_EVENT_STREAM_INSUFFICIENT_BUFFER_LEN); break; - case AWS_EVENT_STREAM_HEADER_UUID: - memcpy(buffer_alias, header->header_value.static_val, 16); - buffer_alias += header->header_value_len; + default: + AWS_FATAL_ASSERT(false && !"Unknown header type!"); break; } } - return buffer_alias - buffer; + return AWS_OP_SUCCESS; +} + +/* adds the headers represented in the headers list to the buffer. + returns the new buffer offset for use elsewhere. Assumes buffer length is at least the length of the return value + from compute_headers_length() */ +size_t aws_event_stream_write_headers_to_buffer(const struct aws_array_list *headers, uint8_t *buffer) { + AWS_FATAL_PRECONDITION(buffer); + + uint32_t min_buffer_len_assumption = aws_event_stream_compute_headers_required_buffer_len(headers); + struct aws_byte_buf safer_buf = aws_byte_buf_from_array(buffer, min_buffer_len_assumption); + + if (aws_event_stream_write_headers_to_buffer_safe(headers, &safer_buf)) { + return 0; + } + + return safer_buf.len; } int aws_event_stream_read_headers_from_buffer( @@ -210,23 +230,28 @@ int aws_event_stream_read_headers_from_buffer( const uint8_t *buffer, size_t headers_len) { - if (AWS_UNLIKELY(headers_len > AWS_EVENT_STREAM_MAX_HEADERS_SIZE)) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(buffer); + + if (AWS_UNLIKELY(headers_len > (size_t)AWS_EVENT_STREAM_MAX_HEADERS_SIZE)) { return aws_raise_error(AWS_ERROR_EVENT_STREAM_MESSAGE_FIELD_SIZE_EXCEEDED); } + struct aws_byte_cursor buffer_cur = aws_byte_cursor_from_array(buffer, headers_len); /* iterate the buffer per header. */ - const uint8_t *buffer_start = buffer; - while ((size_t)(buffer - buffer_start) < headers_len) { + while (buffer_cur.len) { struct aws_event_stream_header_value_pair header; AWS_ZERO_STRUCT(header); /* get the header info from the buffer, make sure to increment buffer offset. */ - header.header_name_len = *buffer; - buffer += sizeof(header.header_name_len); - memcpy((void *)header.header_name, buffer, (size_t)header.header_name_len); - buffer += header.header_name_len; - header.header_value_type = (enum aws_event_stream_header_value_type) * buffer; - buffer++; + aws_byte_cursor_read_u8(&buffer_cur, &header.header_name_len); + AWS_RETURN_ERROR_IF(header.header_name_len <= INT8_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read(&buffer_cur, header.header_name, (size_t)header.header_name_len), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read_u8(&buffer_cur, (uint8_t *)&header.header_value_type), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); switch (header.header_value_type) { case AWS_EVENT_STREAM_HEADER_BOOL_FALSE: @@ -239,36 +264,46 @@ int aws_event_stream_read_headers_from_buffer( break; case AWS_EVENT_STREAM_HEADER_BYTE: header.header_value_len = sizeof(uint8_t); - header.header_value.static_val[0] = *buffer; - buffer++; + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read(&buffer_cur, header.header_value.static_val, header.header_value_len), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); break; case AWS_EVENT_STREAM_HEADER_INT16: header.header_value_len = sizeof(uint16_t); - memcpy(header.header_value.static_val, buffer, sizeof(uint16_t)); - buffer += sizeof(uint16_t); + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read(&buffer_cur, header.header_value.static_val, header.header_value_len), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); break; case AWS_EVENT_STREAM_HEADER_INT32: header.header_value_len = sizeof(uint32_t); - memcpy(header.header_value.static_val, buffer, sizeof(uint32_t)); - buffer += sizeof(uint32_t); + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read(&buffer_cur, header.header_value.static_val, header.header_value_len), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); break; case AWS_EVENT_STREAM_HEADER_INT64: case AWS_EVENT_STREAM_HEADER_TIMESTAMP: header.header_value_len = sizeof(uint64_t); - memcpy(header.header_value.static_val, buffer, sizeof(uint64_t)); - buffer += sizeof(uint64_t); + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read(&buffer_cur, header.header_value.static_val, header.header_value_len), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); break; case AWS_EVENT_STREAM_HEADER_BYTE_BUF: case AWS_EVENT_STREAM_HEADER_STRING: - header.header_value_len = aws_read_u16(buffer); - buffer += sizeof(header.header_value_len); - header.header_value.variable_len_val = (uint8_t *)buffer; - buffer += header.header_value_len; + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read_be16(&buffer_cur, &header.header_value_len), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); + AWS_RETURN_ERROR_IF( + header.header_value_len <= INT16_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + AWS_RETURN_ERROR_IF( + buffer_cur.len >= header.header_value_len, AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); + header.header_value.variable_len_val = (uint8_t *)buffer_cur.ptr; + aws_byte_cursor_advance(&buffer_cur, header.header_value_len); break; case AWS_EVENT_STREAM_HEADER_UUID: - header.header_value_len = 16; - memcpy(header.header_value.static_val, buffer, 16); - buffer += header.header_value_len; + header.header_value_len = UUID_LEN; + AWS_RETURN_ERROR_IF( + aws_byte_cursor_read(&buffer_cur, header.header_value.static_val, UUID_LEN), + AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); break; } @@ -288,6 +323,8 @@ int aws_event_stream_message_init( struct aws_allocator *alloc, struct aws_array_list *headers, struct aws_byte_buf *payload) { + AWS_FATAL_PRECONDITION(message); + AWS_FATAL_PRECONDITION(alloc); size_t payload_len = payload ? payload->len : 0; @@ -309,39 +346,33 @@ int aws_event_stream_message_init( } message->alloc = alloc; - message->message_buffer = aws_mem_acquire(message->alloc, total_length); + aws_byte_buf_init(&message->message_buffer, message->alloc, total_length); - if (message->message_buffer) { - message->owns_buffer = 1; - aws_write_u32(total_length, message->message_buffer); - uint8_t *buffer_offset = message->message_buffer + sizeof(total_length); - aws_write_u32(headers_length, buffer_offset); - buffer_offset += sizeof(headers_length); + aws_byte_buf_write_be32(&message->message_buffer, total_length); + aws_byte_buf_write_be32(&message->message_buffer, headers_length); - uint32_t running_crc = - aws_checksums_crc32(message->message_buffer, (int)(buffer_offset - message->message_buffer), 0); + uint32_t running_crc = aws_checksums_crc32(message->message_buffer.buffer, (int)message->message_buffer.len, 0); - const uint8_t *message_crc_boundary_start = buffer_offset; - aws_write_u32(running_crc, buffer_offset); - buffer_offset += sizeof(running_crc); - - if (headers_length) { - buffer_offset += aws_event_stream_write_headers_to_buffer(headers, buffer_offset); - } + const uint8_t *pre_prelude_marker = message->message_buffer.buffer + message->message_buffer.len; + size_t pre_prelude_position_marker = message->message_buffer.len; + aws_byte_buf_write_be32(&message->message_buffer, running_crc); - if (payload) { - memcpy(buffer_offset, payload->buffer, payload->len); - buffer_offset += payload->len; + if (headers_length) { + if (aws_event_stream_write_headers_to_buffer_safe(headers, &message->message_buffer)) { + aws_event_stream_message_clean_up(message); + return AWS_OP_ERR; } + } - running_crc = aws_checksums_crc32( - message_crc_boundary_start, (int)(buffer_offset - message_crc_boundary_start), running_crc); - aws_write_u32(running_crc, buffer_offset); - - return AWS_OP_SUCCESS; + if (payload) { + aws_byte_buf_write_from_whole_buffer(&message->message_buffer, *payload); } - return aws_raise_error(AWS_ERROR_OOM); + running_crc = aws_checksums_crc32( + pre_prelude_marker, (int)(message->message_buffer.len - pre_prelude_position_marker), running_crc); + aws_byte_buf_write_be32(&message->message_buffer, running_crc); + + return AWS_OP_SUCCESS; } /* add buffer to the message (non-owning). Verify buffer crcs and that length fields are reasonable. */ @@ -349,16 +380,20 @@ int aws_event_stream_message_from_buffer( struct aws_event_stream_message *message, struct aws_allocator *alloc, struct aws_byte_buf *buffer) { - AWS_ASSERT(buffer); + AWS_FATAL_PRECONDITION(message); + AWS_FATAL_PRECONDITION(alloc); + AWS_FATAL_PRECONDITION(buffer); message->alloc = alloc; - message->owns_buffer = 0; if (AWS_UNLIKELY(buffer->len < AWS_EVENT_STREAM_PRELUDE_LENGTH + AWS_EVENT_STREAM_TRAILER_LENGTH)) { return aws_raise_error(AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); } - uint32_t message_length = aws_read_u32(buffer->buffer + TOTAL_LEN_OFFSET); + struct aws_byte_cursor parsing_cur = aws_byte_cursor_from_buf(buffer); + + uint32_t message_length = 0; + aws_byte_cursor_read_be32(&parsing_cur, &message_length); if (AWS_UNLIKELY(message_length != buffer->len)) { return aws_raise_error(AWS_ERROR_EVENT_STREAM_BUFFER_LENGTH_MISMATCH); @@ -367,17 +402,21 @@ int aws_event_stream_message_from_buffer( if (AWS_UNLIKELY(message_length > AWS_EVENT_STREAM_MAX_MESSAGE_SIZE)) { return aws_raise_error(AWS_ERROR_EVENT_STREAM_MESSAGE_FIELD_SIZE_EXCEEDED); } - + /* skip the headers for the moment, we'll handle those later. */ + aws_byte_cursor_advance(&parsing_cur, sizeof(uint32_t)); uint32_t running_crc = aws_checksums_crc32(buffer->buffer, (int)PRELUDE_CRC_OFFSET, 0); - uint32_t prelude_crc = aws_read_u32(buffer->buffer + PRELUDE_CRC_OFFSET); + uint32_t prelude_crc = 0; + const uint8_t *start_of_payload_checksum = parsing_cur.ptr; + size_t start_of_payload_checksum_pos = PRELUDE_CRC_OFFSET; + aws_byte_cursor_read_be32(&parsing_cur, &prelude_crc); if (running_crc != prelude_crc) { return aws_raise_error(AWS_ERROR_EVENT_STREAM_PRELUDE_CHECKSUM_FAILURE); } running_crc = aws_checksums_crc32( - buffer->buffer + PRELUDE_CRC_OFFSET, - (int)(message_length - PRELUDE_CRC_OFFSET - AWS_EVENT_STREAM_TRAILER_LENGTH), + start_of_payload_checksum, + (int)(message_length - start_of_payload_checksum_pos - AWS_EVENT_STREAM_TRAILER_LENGTH), running_crc); uint32_t message_crc = aws_read_u32(buffer->buffer + message_length - AWS_EVENT_STREAM_TRAILER_LENGTH); @@ -385,11 +424,14 @@ int aws_event_stream_message_from_buffer( return aws_raise_error(AWS_ERROR_EVENT_STREAM_MESSAGE_CHECKSUM_FAILURE); } - message->message_buffer = buffer->buffer; + message->message_buffer = *buffer; + /* we don't own this buffer, this is a zero allocation/copy path. Setting allocator to null will prevent the + * clean_up from attempting to free it */ + message->message_buffer.allocator = NULL; if (aws_event_stream_message_headers_len(message) > message_length - AWS_EVENT_STREAM_PRELUDE_LENGTH - AWS_EVENT_STREAM_TRAILER_LENGTH) { - message->message_buffer = 0; + AWS_ZERO_STRUCT(message->message_buffer); return aws_raise_error(AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); } @@ -404,17 +446,9 @@ int aws_event_stream_message_from_buffer_copy( int parse_value = aws_event_stream_message_from_buffer(message, alloc, (struct aws_byte_buf *)buffer); if (!parse_value) { - message->message_buffer = aws_mem_acquire(alloc, buffer->len); - - if (message->message_buffer) { - memcpy(message->message_buffer, buffer->buffer, buffer->len); - message->alloc = alloc; - message->owns_buffer = 1; - - return AWS_OP_SUCCESS; - } - - return aws_raise_error(AWS_ERROR_OOM); + aws_byte_buf_init_copy(&message->message_buffer, alloc, buffer); + message->alloc = alloc; + return AWS_OP_SUCCESS; } return parse_value; @@ -422,47 +456,75 @@ int aws_event_stream_message_from_buffer_copy( /* if buffer is owned, release the memory. */ void aws_event_stream_message_clean_up(struct aws_event_stream_message *message) { - if (message->message_buffer && message->owns_buffer) { - aws_mem_release(message->alloc, message->message_buffer); - } + aws_byte_buf_clean_up(&message->message_buffer); } uint32_t aws_event_stream_message_total_length(const struct aws_event_stream_message *message) { - return aws_read_u32(message->message_buffer + TOTAL_LEN_OFFSET); + struct aws_byte_cursor read_cur = aws_byte_cursor_from_buf(&message->message_buffer); + aws_byte_cursor_advance(&read_cur, TOTAL_LEN_OFFSET); + uint32_t total_len = 0; + aws_byte_cursor_read_be32(&read_cur, &total_len); + + return total_len; } uint32_t aws_event_stream_message_headers_len(const struct aws_event_stream_message *message) { - return aws_read_u32(message->message_buffer + HEADER_LEN_OFFSET); + struct aws_byte_cursor read_cur = aws_byte_cursor_from_buf(&message->message_buffer); + aws_byte_cursor_advance(&read_cur, HEADER_LEN_OFFSET); + + uint32_t headers_len = 0; + aws_byte_cursor_read_be32(&read_cur, &headers_len); + + return headers_len; } uint32_t aws_event_stream_message_prelude_crc(const struct aws_event_stream_message *message) { - return aws_read_u32(message->message_buffer + PRELUDE_CRC_OFFSET); + struct aws_byte_cursor read_cur = aws_byte_cursor_from_buf(&message->message_buffer); + aws_byte_cursor_advance(&read_cur, PRELUDE_CRC_OFFSET); + + uint32_t prelude_crc = 0; + aws_byte_cursor_read_be32(&read_cur, &prelude_crc); + + return prelude_crc; } int aws_event_stream_message_headers(const struct aws_event_stream_message *message, struct aws_array_list *headers) { + struct aws_byte_cursor read_cur = aws_byte_cursor_from_buf(&message->message_buffer); + aws_byte_cursor_advance(&read_cur, AWS_EVENT_STREAM_PRELUDE_LENGTH); + return aws_event_stream_read_headers_from_buffer( - headers, - message->message_buffer + AWS_EVENT_STREAM_PRELUDE_LENGTH, - aws_event_stream_message_headers_len(message)); + headers, read_cur.ptr, aws_event_stream_message_headers_len(message)); } const uint8_t *aws_event_stream_message_payload(const struct aws_event_stream_message *message) { - return message->message_buffer + AWS_EVENT_STREAM_PRELUDE_LENGTH + aws_event_stream_message_headers_len(message); + AWS_FATAL_PRECONDITION(message); + struct aws_byte_cursor read_cur = aws_byte_cursor_from_buf(&message->message_buffer); + aws_byte_cursor_advance(&read_cur, AWS_EVENT_STREAM_PRELUDE_LENGTH + aws_event_stream_message_headers_len(message)); + return read_cur.ptr; } uint32_t aws_event_stream_message_payload_len(const struct aws_event_stream_message *message) { + AWS_FATAL_PRECONDITION(message); return aws_event_stream_message_total_length(message) - (AWS_EVENT_STREAM_PRELUDE_LENGTH + aws_event_stream_message_headers_len(message) + AWS_EVENT_STREAM_TRAILER_LENGTH); } uint32_t aws_event_stream_message_message_crc(const struct aws_event_stream_message *message) { - return aws_read_u32( - message->message_buffer + (aws_event_stream_message_total_length(message) - AWS_EVENT_STREAM_TRAILER_LENGTH)); + AWS_FATAL_PRECONDITION(message); + struct aws_byte_cursor read_cur = aws_byte_cursor_from_buf(&message->message_buffer); + aws_byte_cursor_advance( + &read_cur, aws_event_stream_message_total_length(message) - AWS_EVENT_STREAM_TRAILER_LENGTH); + + uint32_t message_crc = 0; + aws_byte_cursor_read_be32(&read_cur, &message_crc); + + return message_crc; } const uint8_t *aws_event_stream_message_buffer(const struct aws_event_stream_message *message) { - return message->message_buffer; + AWS_FATAL_PRECONDITION(message); + return message->message_buffer.buffer; } #define DEBUG_STR_PRELUDE_TOTAL_LEN "\"total_length\": " @@ -474,6 +536,9 @@ const uint8_t *aws_event_stream_message_buffer(const struct aws_event_stream_mes #define DEBUG_STR_HEADER_TYPE "\"type\": " int aws_event_stream_message_to_debug_str(FILE *fd, const struct aws_event_stream_message *message) { + AWS_FATAL_PRECONDITION(fd); + AWS_FATAL_PRECONDITION(message); + struct aws_array_list headers; aws_event_stream_headers_list_init(&headers, message->alloc); aws_event_stream_message_headers(message, &headers); @@ -527,9 +592,6 @@ int aws_event_stream_message_to_debug_str(FILE *fd, const struct aws_event_strea size_t buffer_len = 0; aws_base64_compute_encoded_len(header->header_value_len, &buffer_len); char *encoded_buffer = (char *)aws_mem_acquire(message->alloc, buffer_len); - if (!encoded_buffer) { - return aws_raise_error(AWS_ERROR_OOM); - } struct aws_byte_buf encode_output = aws_byte_buf_from_array((uint8_t *)encoded_buffer, buffer_len); @@ -565,10 +627,6 @@ int aws_event_stream_message_to_debug_str(FILE *fd, const struct aws_event_strea aws_base64_compute_encoded_len(payload_len, &encoded_len); char *encoded_payload = (char *)aws_mem_acquire(message->alloc, encoded_len); - if (!encoded_payload) { - return aws_raise_error(AWS_ERROR_OOM); - } - struct aws_byte_cursor payload_buffer = aws_byte_cursor_from_array(payload, payload_len); struct aws_byte_buf encoded_payload_buffer = aws_byte_buf_from_array((uint8_t *)encoded_payload, encoded_len); @@ -580,13 +638,15 @@ int aws_event_stream_message_to_debug_str(FILE *fd, const struct aws_event_strea } int aws_event_stream_headers_list_init(struct aws_array_list *headers, struct aws_allocator *allocator) { - AWS_ASSERT(headers); - AWS_ASSERT(allocator); + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(allocator); return aws_array_list_init_dynamic(headers, allocator, 4, sizeof(struct aws_event_stream_header_value_pair)); } void aws_event_stream_headers_list_cleanup(struct aws_array_list *headers) { + AWS_FATAL_PRECONDITION(headers); + if (AWS_UNLIKELY(!headers || !aws_array_list_is_valid(headers))) { return; } @@ -616,10 +676,6 @@ static int s_add_variable_len_header( if (copy) { header->header_value.variable_len_val = aws_mem_acquire(headers->alloc, value_len); - if (!header->header_value.variable_len_val) { - return aws_raise_error(AWS_ERROR_OOM); - } - header->value_owned = 1; memcpy((void *)header->header_value.variable_len_val, (void *)value, value_len); } else { @@ -644,6 +700,10 @@ int aws_event_stream_add_string_header( const char *value, uint16_t value_len, int8_t copy) { + AWS_FATAL_PRECONDITION(headers); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + AWS_RETURN_ERROR_IF(value_len <= INT16_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); struct aws_event_stream_header_value_pair header = {.header_name_len = name_len, .header_value_len = value_len, .value_owned = copy, @@ -655,8 +715,8 @@ int aws_event_stream_add_string_header( struct aws_event_stream_header_value_pair aws_event_stream_create_string_header( struct aws_byte_cursor name, struct aws_byte_cursor value) { - AWS_PRECONDITION(name.len < INT8_MAX); - AWS_PRECONDITION(value.len < INT16_MAX); + AWS_FATAL_PRECONDITION(name.len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX); + AWS_FATAL_PRECONDITION(value.len <= INT16_MAX); struct aws_event_stream_header_value_pair header = { .header_value_type = AWS_EVENT_STREAM_HEADER_STRING, @@ -674,7 +734,7 @@ struct aws_event_stream_header_value_pair aws_event_stream_create_string_header( struct aws_event_stream_header_value_pair aws_event_stream_create_int32_header( struct aws_byte_cursor name, int32_t value) { - AWS_PRECONDITION(name.len < INT8_MAX); + AWS_FATAL_PRECONDITION(name.len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX); struct aws_event_stream_header_value_pair header = { .header_value_type = AWS_EVENT_STREAM_HEADER_INT32, @@ -690,6 +750,11 @@ struct aws_event_stream_header_value_pair aws_event_stream_create_int32_header( } int aws_event_stream_add_byte_header(struct aws_array_list *headers, const char *name, uint8_t name_len, int8_t value) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = {.header_name_len = name_len, .header_value_len = 1, .value_owned = 0, @@ -702,6 +767,11 @@ int aws_event_stream_add_byte_header(struct aws_array_list *headers, const char } int aws_event_stream_add_bool_header(struct aws_array_list *headers, const char *name, uint8_t name_len, int8_t value) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = { .header_name_len = name_len, .header_value_len = 0, @@ -719,6 +789,11 @@ int aws_event_stream_add_int16_header( const char *name, uint8_t name_len, int16_t value) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = { .header_name_len = name_len, .header_value_len = sizeof(value), @@ -737,6 +812,11 @@ int aws_event_stream_add_int32_header( const char *name, uint8_t name_len, int32_t value) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = { .header_name_len = name_len, .header_value_len = sizeof(value), @@ -755,6 +835,12 @@ int aws_event_stream_add_int64_header( const char *name, uint8_t name_len, int64_t value) { + + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = { .header_name_len = name_len, .header_value_len = sizeof(value), @@ -775,6 +861,12 @@ int aws_event_stream_add_bytebuf_header( uint8_t *value, uint16_t value_len, int8_t copy) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + AWS_RETURN_ERROR_IF(value_len <= INT16_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = {.header_name_len = name_len, .header_value_len = value_len, .value_owned = copy, @@ -788,6 +880,11 @@ int aws_event_stream_add_timestamp_header( const char *name, uint8_t name_len, int64_t value) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = { .header_name_len = name_len, .header_value_len = sizeof(uint64_t), @@ -806,60 +903,83 @@ int aws_event_stream_add_uuid_header( const char *name, uint8_t name_len, const uint8_t *value) { + AWS_FATAL_PRECONDITION(headers); + AWS_FATAL_PRECONDITION(name); + AWS_RETURN_ERROR_IF( + name_len <= AWS_EVENT_STREAM_HEADER_NAME_LEN_MAX, AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN); + struct aws_event_stream_header_value_pair header = { .header_name_len = name_len, - .header_value_len = 16, + .header_value_len = UUID_LEN, .value_owned = 0, .header_value_type = AWS_EVENT_STREAM_HEADER_UUID, }; memcpy((void *)header.header_name, (void *)name, (size_t)name_len); - memcpy((void *)header.header_value.static_val, value, 16); + memcpy((void *)header.header_value.static_val, value, UUID_LEN); return aws_array_list_push_back(headers, (void *)&header); } struct aws_byte_buf aws_event_stream_header_name(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); + return aws_byte_buf_from_array((uint8_t *)header->header_name, header->header_name_len); } int8_t aws_event_stream_header_value_as_byte(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); + return (int8_t)header->header_value.static_val[0]; } struct aws_byte_buf aws_event_stream_header_value_as_string(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); + return aws_event_stream_header_value_as_bytebuf(header); } int8_t aws_event_stream_header_value_as_bool(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); + return header->header_value_type == AWS_EVENT_STREAM_HEADER_BOOL_TRUE ? (int8_t)1 : (int8_t)0; } int16_t aws_event_stream_header_value_as_int16(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); + return (int16_t)aws_read_u16(header->header_value.static_val); } int32_t aws_event_stream_header_value_as_int32(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); + return (int32_t)aws_read_u32(header->header_value.static_val); } int64_t aws_event_stream_header_value_as_int64(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); return (int64_t)aws_read_u64(header->header_value.static_val); } struct aws_byte_buf aws_event_stream_header_value_as_bytebuf(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); return aws_byte_buf_from_array(header->header_value.variable_len_val, header->header_value_len); } int64_t aws_event_stream_header_value_as_timestamp(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); return aws_event_stream_header_value_as_int64(header); } struct aws_byte_buf aws_event_stream_header_value_as_uuid(struct aws_event_stream_header_value_pair *header) { - return aws_byte_buf_from_array(header->header_value.static_val, 16); + AWS_FATAL_PRECONDITION(header); + return aws_byte_buf_from_array(header->header_value.static_val, UUID_LEN); } uint16_t aws_event_stream_header_value_length(struct aws_event_stream_header_value_pair *header) { + AWS_FATAL_PRECONDITION(header); + return header->header_value_len; } @@ -920,10 +1040,6 @@ static int s_read_header_value( current_header->header_value.variable_len_val = aws_mem_acquire(decoder->alloc, decoder->current_header.header_value_len); - if (!current_header->header_value.variable_len_val) { - return aws_raise_error(AWS_ERROR_OOM); - } - current_header->value_owned = 1; } } diff --git a/source/event_stream_rpc_client.c b/source/event_stream_rpc_client.c index 5fb493f..443d9ad 100644 --- a/source/event_stream_rpc_client.c +++ b/source/event_stream_rpc_client.c @@ -543,7 +543,18 @@ static int s_send_protocol_message( args->flush_fn = flush_fn; - size_t headers_count = operation_name ? message_args->headers_count + 4 : message_args->headers_count + 3; + size_t headers_count = 0; + + if (operation_name) { + if (aws_add_size_checked(message_args->headers_count, 4, &headers_count)) { + return AWS_OP_ERR; + } + } else { + if (aws_add_size_checked(message_args->headers_count, 3, &headers_count)) { + return AWS_OP_ERR; + } + } + struct aws_array_list headers_list; AWS_ZERO_STRUCT(headers_list); diff --git a/source/event_stream_rpc_server.c b/source/event_stream_rpc_server.c index 6460383..2a91c6f 100644 --- a/source/event_stream_rpc_server.c +++ b/source/event_stream_rpc_server.c @@ -596,7 +596,17 @@ static int s_send_protocol_message( args->flush_fn = flush_fn; - size_t headers_count = message_args->headers_count + 3; + size_t headers_count = 0; + + if (aws_add_size_checked(message_args->headers_count, 3, &headers_count)) { + AWS_LOGF_ERROR( + AWS_LS_EVENT_STREAM_RPC_SERVER, + "id=%p: integer overflow detected when using headers_count %zu", + (void *)connection, + message_args->headers_count); + goto args_allocated_before_failure; + } + struct aws_array_list headers_list; AWS_ZERO_STRUCT(headers_list); diff --git a/tests/event_stream_rpc_client_connection_test.c b/tests/event_stream_rpc_client_connection_test.c index 3b5f886..91278ce 100644 --- a/tests/event_stream_rpc_client_connection_test.c +++ b/tests/event_stream_rpc_client_connection_test.c @@ -27,8 +27,10 @@ struct client_test_data { struct aws_byte_buf received_payload; struct aws_event_stream_rpc_server_continuation_token *server_token; struct aws_byte_buf last_seen_operation_name; - bool message_sent; - bool message_received; + bool client_message_sent; + bool client_message_received; + bool server_message_sent; + bool server_message_received; bool client_token_closed; bool server_token_closed; }; @@ -407,14 +409,31 @@ static void s_rpc_client_message_flush(int error_code, void *user_data) { struct client_test_data *client_test_data = user_data; aws_mutex_lock(&client_test_data->sync_lock); - client_test_data->message_sent = true; + client_test_data->client_message_sent = true; + aws_condition_variable_notify_one(&client_test_data->sync_cvar); + /* make these pessimistic to prevent a cleanup race. */ aws_mutex_unlock(&client_test_data->sync_lock); +} + +static void s_rpc_server_message_flush(int error_code, void *user_data) { + (void)error_code; + + struct client_test_data *client_test_data = user_data; + aws_mutex_lock(&client_test_data->sync_lock); + client_test_data->server_message_sent = true; aws_condition_variable_notify_one(&client_test_data->sync_cvar); + /* make these pessimistic to prevent a cleanup race. */ + aws_mutex_unlock(&client_test_data->sync_lock); } static bool s_rpc_client_message_transmission_completed_pred(void *arg) { struct client_test_data *client_test_data = arg; - return client_test_data->message_sent && client_test_data->message_received; + return client_test_data->client_message_sent && client_test_data->server_message_received; +} + +static bool s_rpc_server_message_transmission_completed_pred(void *arg) { + struct client_test_data *client_test_data = arg; + return client_test_data->server_message_sent && client_test_data->client_message_received; } static void s_rpc_server_connection_protocol_message( @@ -425,7 +444,7 @@ static void s_rpc_server_connection_protocol_message( struct client_test_data *client_test_data = user_data; aws_mutex_lock(&client_test_data->sync_lock); - client_test_data->message_received = true; + client_test_data->server_message_received = true; client_test_data->received_message_type = message_args->message_type; aws_byte_buf_init_copy(&client_test_data->received_payload, client_test_data->allocator, message_args->payload); aws_mutex_unlock(&client_test_data->sync_lock); @@ -440,7 +459,7 @@ static void s_rpc_client_connection_protocol_message( struct client_test_data *client_test_data = user_data; aws_mutex_lock(&client_test_data->sync_lock); - client_test_data->message_received = true; + client_test_data->client_message_received = true; client_test_data->received_message_type = message_args->message_type; client_test_data->received_message_flags = message_args->message_flags; aws_byte_buf_init_copy(&client_test_data->received_payload, client_test_data->allocator, message_args->payload); @@ -489,19 +508,21 @@ static int s_test_event_stream_rpc_client_connection_connect(struct aws_allocato aws_byte_buf_clean_up(&client_test_data.received_payload); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; connect_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK; connect_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED; ASSERT_SUCCESS(aws_event_stream_rpc_server_connection_send_protocol_message( - test_data->server_connection, &connect_args, s_rpc_client_message_flush, &client_test_data)); + test_data->server_connection, &connect_args, s_rpc_server_message_flush, &client_test_data)); aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); ASSERT_INT_EQUALS(AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK, client_test_data.received_message_type); @@ -603,19 +624,21 @@ static int s_test_event_stream_rpc_client_connection_protocol_message(struct aws aws_byte_buf_clean_up(&client_test_data.received_payload); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; connect_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK; connect_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED; ASSERT_SUCCESS(aws_event_stream_rpc_server_connection_send_protocol_message( - test_data->server_connection, &connect_args, s_rpc_client_message_flush, &client_test_data)); + test_data->server_connection, &connect_args, s_rpc_server_message_flush, &client_test_data)); aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); ASSERT_INT_EQUALS(AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK, client_test_data.received_message_type); @@ -627,8 +650,10 @@ static int s_test_event_stream_rpc_client_connection_protocol_message(struct aws aws_byte_buf_clean_up(&client_test_data.received_payload); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; struct aws_byte_buf ping_payload = aws_byte_buf_from_c_str("{ \"message\": \"hello device that will further isolate humans from each other " @@ -661,16 +686,18 @@ static int s_test_event_stream_rpc_client_connection_protocol_message(struct aws ping_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING_RESPONSE; client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; ASSERT_SUCCESS(aws_event_stream_rpc_server_connection_send_protocol_message( - test_data->server_connection, &ping_args, s_rpc_client_message_flush, &client_test_data)); + test_data->server_connection, &ping_args, s_rpc_server_message_flush, &client_test_data)); aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); ASSERT_INT_EQUALS(AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING_RESPONSE, client_test_data.received_message_type); @@ -706,7 +733,7 @@ static void s_rpc_client_stream_continuation( struct client_test_data *client_test_data = user_data; aws_mutex_lock(&client_test_data->sync_lock); - client_test_data->message_received = true; + client_test_data->client_message_received = true; client_test_data->received_message_type = message_args->message_type; aws_byte_buf_init_copy(&client_test_data->received_payload, client_test_data->allocator, message_args->payload); aws_mutex_unlock(&client_test_data->sync_lock); @@ -755,7 +782,7 @@ static void s_rpc_server_stream_continuation( struct client_test_data *client_test_data = user_data; aws_mutex_lock(&client_test_data->sync_lock); - client_test_data->message_received = true; + client_test_data->server_message_received = true; client_test_data->received_message_type = message_args->message_type; aws_byte_buf_init_copy(&client_test_data->received_payload, client_test_data->allocator, message_args->payload); aws_mutex_unlock(&client_test_data->sync_lock); @@ -818,19 +845,21 @@ static int s_test_event_stream_rpc_client_connection_continuation_flow(struct aw aws_byte_buf_clean_up(&client_test_data.received_payload); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; connect_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK; connect_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED; ASSERT_SUCCESS(aws_event_stream_rpc_server_connection_send_protocol_message( - test_data->server_connection, &connect_args, s_rpc_client_message_flush, &client_test_data)); + test_data->server_connection, &connect_args, s_rpc_server_message_flush, &client_test_data)); aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); ASSERT_INT_EQUALS(AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK, client_test_data.received_message_type); @@ -842,8 +871,10 @@ static int s_test_event_stream_rpc_client_connection_continuation_flow(struct aw aws_byte_buf_clean_up(&client_test_data.received_payload); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; struct aws_event_stream_rpc_client_stream_continuation_options continuation_options = { .user_data = &client_test_data, @@ -890,19 +921,21 @@ static int s_test_event_stream_rpc_client_connection_continuation_flow(struct aw aws_byte_buf_clean_up(&client_test_data.last_seen_operation_name); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; operation_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM; operation_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_ERROR; ASSERT_SUCCESS(aws_event_stream_rpc_server_continuation_send_message( - client_test_data.server_token, &operation_args, s_rpc_client_message_flush, &client_test_data)); + client_test_data.server_token, &operation_args, s_rpc_server_message_flush, &client_test_data)); aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); aws_condition_variable_wait_pred( @@ -986,19 +1019,21 @@ static int s_test_event_stream_rpc_client_connection_unactivated_continuation_fa aws_byte_buf_clean_up(&client_test_data.received_payload); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; connect_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK; connect_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED; ASSERT_SUCCESS(aws_event_stream_rpc_server_connection_send_protocol_message( - test_data->server_connection, &connect_args, s_rpc_client_message_flush, &client_test_data)); + test_data->server_connection, &connect_args, s_rpc_server_message_flush, &client_test_data)); aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); ASSERT_INT_EQUALS(AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK, client_test_data.received_message_type); @@ -1010,8 +1045,10 @@ static int s_test_event_stream_rpc_client_connection_unactivated_continuation_fa aws_byte_buf_clean_up(&client_test_data.received_payload); client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; struct aws_event_stream_rpc_client_stream_continuation_options continuation_options = { .user_data = &client_test_data, @@ -1106,20 +1143,22 @@ static int s_test_event_stream_rpc_client_connection_continuation_send_message_o /* server sends CONNECT_ACK */ client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; connect_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK; connect_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED; ASSERT_SUCCESS(aws_event_stream_rpc_server_connection_send_protocol_message( - test_data->server_connection, &connect_args, s_rpc_client_message_flush, &client_test_data)); + test_data->server_connection, &connect_args, s_rpc_server_message_flush, &client_test_data)); /* ...wait until sent and received... */ aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); ASSERT_INT_EQUALS(AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK, client_test_data.received_message_type); @@ -1132,8 +1171,10 @@ static int s_test_event_stream_rpc_client_connection_continuation_send_message_o /* client sends message creating new stream */ client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; struct aws_event_stream_rpc_client_stream_continuation_options continuation_options = { .user_data = &client_test_data, @@ -1182,20 +1223,22 @@ static int s_test_event_stream_rpc_client_connection_continuation_send_message_o /* server sends response with TERMINATE_STREAM flag set */ client_test_data.received_message_type = 0; - client_test_data.message_received = false; - client_test_data.message_sent = false; + client_test_data.client_message_received = false; + client_test_data.client_message_sent = false; + client_test_data.server_message_received = false; + client_test_data.server_message_sent = false; operation_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM; operation_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_ERROR; ASSERT_SUCCESS(aws_event_stream_rpc_server_continuation_send_message( - client_test_data.server_token, &operation_args, s_rpc_client_message_flush, &client_test_data)); + client_test_data.server_token, &operation_args, s_rpc_server_message_flush, &client_test_data)); /* ...wait until sent and received... */ aws_condition_variable_wait_pred( &client_test_data.sync_cvar, &client_test_data.sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, &client_test_data); /* ...wait until client stream closed... */ @@ -1284,19 +1327,21 @@ static int s_test_event_stream_rpc_client_connection_continuation_duplicated_act aws_byte_buf_clean_up(&client_test_data->received_payload); client_test_data->received_message_type = 0; - client_test_data->message_received = false; - client_test_data->message_sent = false; + client_test_data->client_message_received = false; + client_test_data->client_message_sent = false; + client_test_data->server_message_received = false; + client_test_data->server_message_sent = false; connect_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK; connect_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED; ASSERT_SUCCESS(aws_event_stream_rpc_server_connection_send_protocol_message( - test_data->server_connection, &connect_args, s_rpc_client_message_flush, client_test_data)); + test_data->server_connection, &connect_args, s_rpc_server_message_flush, client_test_data)); aws_condition_variable_wait_pred( &client_test_data->sync_cvar, &client_test_data->sync_lock, - s_rpc_client_message_transmission_completed_pred, + s_rpc_server_message_transmission_completed_pred, client_test_data); ASSERT_INT_EQUALS(AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK, client_test_data->received_message_type); @@ -1308,8 +1353,10 @@ static int s_test_event_stream_rpc_client_connection_continuation_duplicated_act aws_byte_buf_clean_up(&client_test_data->received_payload); client_test_data->received_message_type = 0; - client_test_data->message_received = false; - client_test_data->message_sent = false; + client_test_data->client_message_received = false; + client_test_data->client_message_sent = false; + client_test_data->server_message_received = false; + client_test_data->server_message_sent = false; struct aws_event_stream_rpc_client_stream_continuation_options continuation_options = { .user_data = client_test_data, diff --git a/tests/message_deserializer_test.c b/tests/message_deserializer_test.c index 42d274b..a57588b 100644 --- a/tests/message_deserializer_test.c +++ b/tests/message_deserializer_test.c @@ -15,7 +15,8 @@ static int s_test_outgoing_no_op_valid_fn(struct aws_allocator *allocator, void struct aws_event_stream_message message; struct aws_byte_buf test_buf = aws_byte_buf_from_array(test_data, sizeof(test_data)); ASSERT_SUCCESS( - aws_event_stream_message_from_buffer(&message, NULL, &test_buf), "Message validation should have succeeded"); + aws_event_stream_message_from_buffer(&message, allocator, &test_buf), + "Message validation should have succeeded"); ASSERT_INT_EQUALS( 0x00000010, aws_event_stream_message_total_length(&message), "Message length should have been 0x10"); @@ -43,7 +44,8 @@ static int s_test_outgoing_application_data_no_headers_valid_fn(struct aws_alloc struct aws_byte_buf test_buf = aws_byte_buf_from_array(test_data, sizeof(test_data)); ASSERT_SUCCESS( - aws_event_stream_message_from_buffer(&message, NULL, &test_buf), "Message validation should have succeeded"); + aws_event_stream_message_from_buffer(&message, allocator, &test_buf), + "Message validation should have succeeded"); ASSERT_INT_EQUALS( 0x0000001D, aws_event_stream_message_total_length(&message), "Message length should have been 0x0000001D");