From 8be980e7890445c3136c8db1429f8b2a16b26ae2 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Fri, 6 Mar 2020 16:40:03 +0200 Subject: [PATCH] provides dedicated an operator per request Signed-off-by: Oleh Dokuka --- .../core/FireAndForgetRequesterMono.java | 194 +++ .../FireAndForgetResponderSubscriber.java | 123 ++ .../io/rsocket/core/FragmentationUtils.java | 219 ++++ .../FrameHandler.java} | 24 +- .../core/MetadataPushRequesterMono.java | 184 +++ .../core/MetadataPushResponderSubscriber.java | 45 + .../rsocket/core/PayloadValidationUtils.java | 45 +- .../io/rsocket/core/RSocketConnector.java | 26 +- .../io/rsocket/core/RSocketRequester.java | 667 ++-------- .../io/rsocket/core/RSocketResponder.java | 476 ++----- .../java/io/rsocket/core/RSocketServer.java | 21 +- .../java/io/rsocket/core/ReassemblyUtils.java | 237 ++++ .../core/RequestChannelRequesterFlux.java | 549 ++++++++ .../RequestChannelResponderSubscriber.java | 726 +++++++++++ .../java/io/rsocket/core/RequestOperator.java | 189 --- .../core/RequestResponseRequesterMono.java | 290 +++++ .../RequestResponseResponderSubscriber.java | 244 ++++ .../core/RequestStreamRequesterFlux.java | 313 +++++ .../RequestStreamResponderSubscriber.java | 287 +++++ .../rsocket/core/RequesterFrameHandler.java | 43 + .../core/RequesterResponderSupport.java | 128 ++ .../rsocket/core/ResponderFrameHandler.java | 38 + .../main/java/io/rsocket/core/SendUtils.java | 326 +++++ .../main/java/io/rsocket/core/StateUtils.java | 385 ++++++ .../FragmentationDuplexConnection.java | 112 -- .../fragmentation/FrameFragmenter.java | 235 ---- .../fragmentation/FrameReassembler.java | 342 ----- .../ReassemblyDuplexConnection.java | 89 -- .../SynchronizedIntObjectHashMap.java | 748 ----------- .../java/io/rsocket/util/ByteBufPayload.java | 2 +- .../src/test/java/io/rsocket/FrameAssert.java | 336 +++++ .../test/java/io/rsocket/PayloadAssert.java | 180 +++ .../io/rsocket/core/AbstractSocketRule.java | 8 +- .../core/DefaultRSocketClientTests.java | 51 +- .../core/FireAndForgetRequesterMonoTest.java | 404 ++++++ .../java/io/rsocket/core/KeepAliveTest.java | 9 +- .../core/PayloadValidationUtilsTest.java | 60 +- .../io/rsocket/core/RSocketLeaseTest.java | 55 +- .../io/rsocket/core/RSocketReconnectTest.java | 13 - .../core/RSocketRequesterSubscribersTest.java | 63 +- .../io/rsocket/core/RSocketRequesterTest.java | 155 ++- .../io/rsocket/core/RSocketResponderTest.java | 11 +- .../java/io/rsocket/core/RSocketTest.java | 10 +- .../core/RequestChannelRequesterFluxTest.java | 756 +++++++++++ ...RequestChannelResponderSubscriberTest.java | 688 ++++++++++ .../RequestResponseRequesterMonoTest.java | 695 ++++++++++ .../core/RequestStreamRequesterFluxTest.java | 1146 +++++++++++++++++ .../core/RequesterOperatorsRacingTest.java | 702 ++++++++++ .../core/ResponderOperatorsCommonTest.java | 413 ++++++ .../io/rsocket/core/SetupRejectionTest.java | 8 +- .../java/io/rsocket/core/ShouldHaveFlag.java | 98 ++ .../io/rsocket/core/ShouldNotHaveFlag.java | 73 ++ .../java/io/rsocket/core/StateAssert.java | 161 +++ .../io/rsocket/core/StreamIdSupplierTest.java | 14 +- .../core/TestRequesterResponderSupport.java | 167 +++ .../io/rsocket/exceptions/ExceptionsTest.java | 1 + .../FragmentationDuplexConnectionTest.java | 112 -- .../FragmentationIntegrationTest.java | 56 - .../fragmentation/FrameFragmenterTest.java | 350 ----- .../fragmentation/FrameReassemblerTest.java | 526 -------- .../ReassembleDuplexConnectionTest.java | 334 ----- .../rsocket/frame/ByteBufRepresentation.java | 17 +- .../test/FragmentationTransportTest.java | 463 +++++++ .../java/io/rsocket/test/TestRSocket.java | 14 +- .../java/io/rsocket/test/TransportTest.java | 7 +- .../netty/TcpFragmentationTransportTest.java | 42 + .../netty/WebsocketSecureTransportTest.java | 2 +- 67 files changed, 11286 insertions(+), 4221 deletions(-) create mode 100644 rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java rename rsocket-core/src/main/java/io/rsocket/{fragmentation/package-info.java => core/FrameHandler.java} (67%) create mode 100644 rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/SendUtils.java create mode 100644 rsocket-core/src/main/java/io/rsocket/core/StateUtils.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java create mode 100644 rsocket-core/src/test/java/io/rsocket/FrameAssert.java create mode 100755 rsocket-core/src/test/java/io/rsocket/PayloadAssert.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java create mode 100755 rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/StateAssert.java create mode 100644 rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java delete mode 100644 rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java delete mode 100644 rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java delete mode 100644 rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java delete mode 100644 rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java delete mode 100644 rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java create mode 100644 rsocket-test/src/main/java/io/rsocket/test/FragmentationTransportTest.java create mode 100644 rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java new file mode 100644 index 000000000..3d7a3dfa7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -0,0 +1,194 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.UnboundedProcessor; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class FireAndForgetRequesterMono extends Mono implements Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(FireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final UnboundedProcessor sendProcessor; + + FireAndForgetRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("FireAndForgetMono allows only a single Subscriber")); + return; + } + + actual.onSubscribe(this); + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + p.release(); + actual.onError( + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + actual.onError(e); + return; + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + p.release(); + actual.onError(Exceptions.unwrap(t)); + return; + } + + try { + if (isTerminated(this.state)) { + p.release(); + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.sendProcessor, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + markTerminated(STATE, this); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + throw new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw Exceptions.propagate(e); + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + p.release(); + throw Exceptions.propagate(t); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_FNF, + this.mtu, + this.payload, + this.sendProcessor, + this.allocator, + true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + throw Exceptions.propagate(e); + } + + lazyTerminate(STATE, this); + return null; + } + + @Override + public Object scanUnsafe(Scannable.Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java new file mode 100644 index 000000000..8933d4089 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -0,0 +1,123 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; + +final class FireAndForgetResponderSubscriber + implements CoreSubscriber, ResponderFrameHandler { + + static final Logger logger = LoggerFactory.getLogger(FireAndForgetResponderSubscriber.class); + + static final FireAndForgetResponderSubscriber INSTANCE = new FireAndForgetResponderSubscriber(); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final RequesterResponderSupport requesterResponderSupport; + final RSocket handler; + final int maxInboundPayloadSize; + + CompositeByteBuf frames; + + private FireAndForgetResponderSubscriber() { + this.streamId = 0; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.handler = handler; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, maxInboundPayloadSize); + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + logger.debug("Dropped Outbound error", t); + } + + @Override + public void onComplete() {} + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = + ReassemblyUtils.addFollowingFrame(this.frames, followingFrame, this.maxInboundPayloadSize); + + if (!hasFollows) { + this.requesterResponderSupport.remove(this.streamId, this); + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + logger.debug("Reassembly has failed", t); + return; + } + + Mono source = this.handler.fireAndForget(payload); + source.subscribe(this); + } + } + + @Override + public final void handleCancel() { + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + this.requesterResponderSupport.remove(this.streamId, this); + frames.release(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java new file mode 100644 index 000000000..092a5312f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java @@ -0,0 +1,219 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import reactor.util.annotation.Nullable; + +class FragmentationUtils { + + static final int MIN_MTU_SIZE = 64; + + static final int FRAME_OFFSET = // 9 bytes in total + FrameLengthCodec.FRAME_LENGTH_SIZE // includes encoded frame length bytes size + + FrameHeaderCodec.size(); // includes encoded frame headers info bytes size + static final int FRAME_OFFSET_WITH_METADATA = // 12 bytes in total + FRAME_OFFSET + + FrameLengthCodec.FRAME_LENGTH_SIZE; // include encoded metadata length bytes size + + static final int FRAME_OFFSET_WITH_INITIAL_REQUEST_N = // 13 bytes in total + FRAME_OFFSET + Integer.BYTES; // includes extra space for initialRequestN bytes size + static final int FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N = // 16 bytes in total + FRAME_OFFSET_WITH_METADATA + + Integer.BYTES; // includes extra space for initialRequestN bytes size + + static boolean isFragmentable( + int mtu, ByteBuf data, @Nullable ByteBuf metadata, boolean hasInitialRequestN) { + if (mtu == 0) { + return false; + } + + if (metadata != null) { + int remaining = + mtu + - (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA); + + return (metadata.readableBytes() + data.readableBytes()) > remaining; + } else { + int remaining = + mtu - (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET); + + return data.readableBytes() > remaining; + } + } + + static ByteBuf encodeFollowsFragment( + ByteBufAllocator allocator, + int mtu, + int streamId, + boolean complete, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + boolean follows = data.isReadable() || metadata.isReadable(); + return PayloadFrameCodec.encode( + allocator, streamId, follows, (!follows && complete), true, metadataFragment, dataFragment); + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + FrameType frameType, + int streamId, + boolean hasMetadata, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = hasMetadata ? Unpooled.EMPTY_BUFFER : null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + case REQUEST_FNF: + return RequestFireAndForgetFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_RESPONSE: + return RequestResponseFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + // Payload and synthetic types from the responder side + case PAYLOAD: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, false, metadataFragment, dataFragment); + case NEXT: + // see https://github.com/rsocket/rsocket/blob/master/Protocol.md#handling-the-unexpected + // point 7 + case NEXT_COMPLETE: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, true, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + long initialRequestN, + FrameType frameType, + int streamId, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length bytes + initial requestN bytes + int remaining = mtu - FRAME_OFFSET_WITH_INITIAL_REQUEST_N; + + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + // Requester Side + case REQUEST_STREAM: + return RequestStreamFrameCodec.encode( + allocator, streamId, true, initialRequestN, metadataFragment, dataFragment); + case REQUEST_CHANNEL: + return RequestChannelFrameCodec.encode( + allocator, streamId, true, false, initialRequestN, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static int assertMtu(int mtu) { + if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } else { + return mtu; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java similarity index 67% rename from rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java rename to rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java index 8cc3fb41a..6d1ee1b09 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java @@ -13,15 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package io.rsocket.core; -/** - * Support for frame fragmentation and reassembly. - * - * @see Fragmentation - * and Reassembly - */ -@NonNullApi -package io.rsocket.fragmentation; +import io.netty.buffer.ByteBuf; + +interface FrameHandler { + + void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload); + + void handleError(Throwable t); + + void handleComplete(); + + void handleCancel(); -import reactor.util.annotation.NonNullApi; + void handleRequestN(long n); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java new file mode 100644 index 000000000..3a53b0ad8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java @@ -0,0 +1,184 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValidMetadata; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class MetadataPushRequesterMono extends Mono implements Scannable { + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(MetadataPushRequesterMono.class, "state"); + + final ByteBufAllocator allocator; + final Payload payload; + final int maxFrameLength; + final UnboundedProcessor sendProcessor; + + MetadataPushRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("MetadataPushMono allows only a single Subscriber")); + return; + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (!hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException("Metadata push should have metadata field present")); + return; + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + Operators.error(actual, e); + return; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.sendProcessor.onNext(requestFrame); + + Operators.complete(actual); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + throw new IllegalStateException("MetadataPushMono allows only a single Subscriber"); + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException("Metadata push does not support metadata field"); + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException("Too Big Payload size"); + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + throw e; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.sendProcessor.onNext(requestFrame); + + return null; + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(MetadataPushMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java new file mode 100644 index 000000000..4c69934e8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; + +final class MetadataPushResponderSubscriber implements CoreSubscriber { + static final Logger logger = LoggerFactory.getLogger(MetadataPushResponderSubscriber.class); + + static final MetadataPushResponderSubscriber INSTANCE = new MetadataPushResponderSubscriber(); + + private MetadataPushResponderSubscriber() {} + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + logger.debug("Dropped error", t); + } + + @Override + public void onComplete() {} +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java index 5e62105c9..6ece319c9 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -1,33 +1,48 @@ package io.rsocket.core; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_INITIAL_REQUEST_N; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import io.netty.buffer.ByteBuf; import io.rsocket.Payload; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameLengthCodec; final class PayloadValidationUtils { static final String INVALID_PAYLOAD_ERROR_MESSAGE = - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."; + "The payload is too big to be send as a single frame with a max frame length %s. Consider enabling fragmentation."; + + static boolean isValid(int mtu, int maxFrameLength, Payload payload, boolean hasInitialRequestN) { - static boolean isValid(int mtu, Payload payload, int maxFrameLength) { if (mtu > 0) { return true; } - if (payload.hasMetadata()) { - return ((FrameHeaderCodec.size() - + FrameLengthCodec.FRAME_LENGTH_SIZE - + FrameHeaderCodec.size() - + payload.data().readableBytes() - + payload.metadata().readableBytes()) - <= maxFrameLength); + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf data = payload.data(); + + int unitSize; + if (hasMetadata) { + final ByteBuf metadata = payload.metadata(); + unitSize = + (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA) + + metadata.readableBytes() + + // metadata payload bytes + data.readableBytes(); // data payload bytes } else { - return ((FrameHeaderCodec.size() - + payload.data().readableBytes() - + FrameLengthCodec.FRAME_LENGTH_SIZE) - <= maxFrameLength); + unitSize = + (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET) + + data.readableBytes(); // data payload bytes } + + return unitSize <= maxFrameLength; + } + + static boolean isValidMetadata(int maxFrameLength, ByteBuf metadata) { + return FRAME_OFFSET + metadata.readableBytes() <= maxFrameLength; } static void assertValidateSetup(int maxFrameLength, int maxInboundPayloadSize, int mtu) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index fdb4859cf..5664eace3 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -15,7 +15,9 @@ */ package io.rsocket.core; +import static io.rsocket.core.FragmentationUtils.assertMtu; import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -25,8 +27,6 @@ import io.rsocket.RSocket; import io.rsocket.RSocketClient; import io.rsocket.SocketAcceptor; -import io.rsocket.fragmentation.FragmentationDuplexConnection; -import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.ClientServerInputMultiplexer; @@ -48,7 +48,6 @@ import java.util.function.Supplier; import reactor.core.Disposable; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; import reactor.util.annotation.Nullable; import reactor.util.function.Tuples; import reactor.util.retry.Retry; @@ -441,8 +440,7 @@ public RSocketConnector lease(Supplier> supplier) { * and Reassembly */ public RSocketConnector maxInboundPayloadSize(int maxInboundPayloadSize) { - this.maxInboundPayloadSize = - ReassemblyDuplexConnection.assertInboundPayloadSize(maxInboundPayloadSize); + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); return this; } @@ -460,7 +458,7 @@ public RSocketConnector maxInboundPayloadSize(int maxInboundPayloadSize) { * and Reassembly */ public RSocketConnector fragment(int mtu) { - this.mtu = FragmentationDuplexConnection.assertMtu(mtu); + this.mtu = assertMtu(mtu); return this; } @@ -576,14 +574,7 @@ private Mono connect0(Supplier transportSupplier) { assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); return ct; }) - .flatMap(transport -> transport.connect()) - .map( - connection -> - mtu > 0 - ? new FragmentationDuplexConnection( - connection, mtu, maxInboundPayloadSize, "client") - : new ReassemblyDuplexConnection( - connection, maxInboundPayloadSize)); + .flatMap(transport -> transport.connect()); return connectionMono .flatMap( @@ -641,11 +632,11 @@ private Mono connect0(Supplier transportSupplier) { StreamIdSupplier.clientSupplier(), mtu, maxFrameLength, + maxInboundPayloadSize, (int) keepAliveInterval.toMillis(), (int) keepAliveMaxLifeTime.toMillis(), keepAliveHandler, - requesterLeaseHandler, - Schedulers.single(Schedulers.parallel())); + requesterLeaseHandler); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); @@ -693,7 +684,8 @@ private Mono connect0(Supplier transportSupplier) { payloadDecoder, responderLeaseHandler, mtu, - maxFrameLength); + maxFrameLength, + maxInboundPayloadSize); return wrappedConnection .sendOne(setupFrame.retain()) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 2ecdec215..00a2cae37 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -16,102 +16,57 @@ package io.rsocket.core; -import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; -import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; -import io.netty.util.ReferenceCounted; -import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.Exceptions; -import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; -import io.rsocket.frame.MetadataPushFrameCodec; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.frame.RequestChannelFrameCodec; -import io.rsocket.frame.RequestFireAndForgetFrameCodec; import io.rsocket.frame.RequestNFrameCodec; -import io.rsocket.frame.RequestResponseFrameCodec; -import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.keepalive.KeepAliveFramesAcceptor; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.keepalive.KeepAliveSupport; import io.rsocket.lease.RequesterLeaseHandler; import java.nio.channels.ClosedChannelException; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Consumer; import java.util.function.Supplier; -import org.reactivestreams.Processor; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; -import reactor.core.publisher.UnicastProcessor; -import reactor.core.scheduler.Scheduler; import reactor.util.annotation.Nullable; -import reactor.util.concurrent.Queues; /** * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer */ -class RSocketRequester implements RSocket { +class RSocketRequester extends RequesterResponderSupport implements RSocket { private static final Logger LOGGER = LoggerFactory.getLogger(RSocketRequester.class); private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); - private static final Consumer DROPPED_ELEMENTS_CONSUMER = - referenceCounted -> { - if (referenceCounted.refCnt() > 0) { - try { - referenceCounted.release(); - } catch (IllegalReferenceCountException e) { - // ignored - } - } - }; static { CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); } private volatile Throwable terminationError; - private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = AtomicReferenceFieldUpdater.newUpdater( RSocketRequester.class, Throwable.class, "terminationError"); private final DuplexConnection connection; - private final PayloadDecoder payloadDecoder; - private final StreamIdSupplier streamIdSupplier; - private final IntObjectMap senders; - private final IntObjectMap> receivers; - private final UnboundedProcessor sendProcessor; - private final int mtu; - private final int maxFrameLength; private final RequesterLeaseHandler leaseHandler; - private final ByteBufAllocator allocator; private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; private final MonoProcessor onClose; - private final Scheduler serialScheduler; RSocketRequester( DuplexConnection connection, @@ -119,26 +74,26 @@ class RSocketRequester implements RSocket { StreamIdSupplier streamIdSupplier, int mtu, int maxFrameLength, + int maxInboundPayloadSize, int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, - RequesterLeaseHandler leaseHandler, - Scheduler serialScheduler) { + RequesterLeaseHandler leaseHandler) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection.alloc(), + streamIdSupplier); + this.connection = connection; - this.allocator = connection.alloc(); - this.payloadDecoder = payloadDecoder; - this.streamIdSupplier = streamIdSupplier; - this.mtu = mtu; - this.maxFrameLength = maxFrameLength; this.leaseHandler = leaseHandler; - this.senders = new SynchronizedIntObjectHashMap<>(); - this.receivers = new SynchronizedIntObjectHashMap<>(); this.onClose = MonoProcessor.create(); - this.serialScheduler = serialScheduler; - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - this.sendProcessor = new UnboundedProcessor<>(); + UnboundedProcessor sendProcessor = super.getSendProcessor(); + // DO NOT Change the order here. The Send processor must be subscribed to before receiving connection.onClose().subscribe(null, this::tryTerminateOnConnectionError, this::tryShutdown); connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); @@ -146,7 +101,7 @@ class RSocketRequester implements RSocket { if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { KeepAliveSupport keepAliveSupport = - new ClientKeepAliveSupport(this.allocator, keepAliveTickPeriod, keepAliveAckTimeout); + new ClientKeepAliveSupport(this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout); this.keepAliveFramesAcceptor = keepAliveHandler.start( keepAliveSupport, sendProcessor::onNextPrioritized, this::tryTerminateOnKeepAlive); @@ -157,477 +112,88 @@ class RSocketRequester implements RSocket { @Override public Mono fireAndForget(Payload payload) { - return handleFireAndForget(payload); + return new FireAndForgetRequesterMono(payload, this); } @Override public Mono requestResponse(Payload payload) { - return handleRequestResponse(payload); + return new RequestResponseRequesterMono(payload, this); } @Override public Flux requestStream(Payload payload) { - return handleRequestStream(payload); + return new RequestStreamRequesterFlux(payload, this); } @Override public Flux requestChannel(Publisher payloads) { - return handleChannel(Flux.from(payloads)); + return new RequestChannelRequesterFlux(payloads, this); } @Override public Mono metadataPush(Payload payload) { - return handleMetadataPush(payload); - } - - @Override - public double availability() { - return Math.min(connection.availability(), leaseHandler.availability()); - } - - @Override - public void dispose() { - tryShutdown(); - } - - @Override - public boolean isDisposed() { - return terminationError != null; - } - - @Override - public Mono onClose() { - return onClose; - } - - private Mono handleFireAndForget(Payload payload) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); - } - - if (isDisposed()) { + Throwable terminationError = this.terminationError; + if (terminationError != null) { payload.release(); - final Throwable t = terminationError; - return Mono.error(t); + return Mono.error(terminationError); } - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); - } - - final AtomicBoolean once = new AtomicBoolean(); - - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("FireAndForgetMono allows only a single subscriber")); - } - - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Mono.error(t); - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - return Mono.error(lh.leaseError()); - } - - final int streamId = streamIdSupplier.nextStreamId(receivers); - final ByteBuf requestFrame = - RequestFireAndForgetFrameCodec.encodeReleasingPayload( - allocator, streamId, payload); - - sendProcessor.onNext(requestFrame); - - return Mono.empty(); - }) - .subscribeOn(serialScheduler); + return new MetadataPushRequesterMono(payload, this); } - private Mono handleRequestResponse(final Payload payload) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); + @Override + public int getNextStreamId() { + RequesterLeaseHandler leaseHandler = this.leaseHandler; + if (!leaseHandler.useLease()) { + throw reactor.core.Exceptions.propagate(leaseHandler.leaseError()); } - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Mono.error(t); - } + int nextStreamId = super.getNextStreamId(); - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + Throwable terminationError = this.terminationError; + if (terminationError != null) { + throw reactor.core.Exceptions.propagate(terminationError); } - final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(Queues.one().get()); - final AtomicBoolean once = new AtomicBoolean(); - - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("RequestResponseMono allows only a single subscriber")); - } - - return receiver - .next() - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - receiver.onError(lh.leaseError()); - return; - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - ByteBuf requestResponseFrame = - RequestResponseFrameCodec.encodeReleasingPayload( - allocator, streamId, payload); - - receivers.put(streamId, receiver); - sendProcessor.onNext(requestResponseFrame); - } - - @Override - void hookOnCancel() { - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else { - if (this.firstRequest) { - payload.release(); - } - } - } - - @Override - public void hookOnTerminal(SignalType signalType) { - receivers.remove(streamId, receiver); - } - })) - .subscribeOn(serialScheduler) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - }); + return nextStreamId; } - private Flux handleRequestStream(final Payload payload) { - if (payload.refCnt() <= 0) { - return Flux.error(new IllegalReferenceCountException()); + @Override + public int addAndGetNextStreamId(FrameHandler frameHandler) { + RequesterLeaseHandler leaseHandler = this.leaseHandler; + if (!leaseHandler.useLease()) { + throw reactor.core.Exceptions.propagate(leaseHandler.leaseError()); } - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Flux.error(t); - } + int nextStreamId = super.addAndGetNextStreamId(frameHandler); - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + Throwable terminationError = this.terminationError; + if (terminationError != null) { + super.remove(nextStreamId, frameHandler); + throw reactor.core.Exceptions.propagate(terminationError); } - final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(); - final AtomicBoolean once = new AtomicBoolean(); - - return Flux.defer( - () -> { - if (once.getAndSet(true)) { - return Flux.error( - new IllegalStateException("RequestStreamFlux allows only a single subscriber")); - } - - return receiver - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - receiver.onError(lh.leaseError()); - return; - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - ByteBuf requestStreamFrame = - RequestStreamFrameCodec.encodeReleasingPayload( - allocator, streamId, n, payload); - - receivers.put(streamId, receiver); - - sendProcessor.onNext(requestStreamFrame); - } - - @Override - void hookOnRemainingRequests(long n) { - if (receiver.isDisposed()) { - return; - } - - sendProcessor.onNext( - RequestNFrameCodec.encode(allocator, streamId, n)); - } - - @Override - void hookOnCancel() { - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else { - if (this.firstRequest) { - payload.release(); - } - } - } - - @Override - void hookOnTerminal(SignalType signalType) { - receivers.remove(streamId); - } - })) - .subscribeOn(serialScheduler, false) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - }); + return nextStreamId; } - private Flux handleChannel(Flux request) { - if (isDisposed()) { - final Throwable t = terminationError; - return Flux.error(t); - } - - return request - .switchOnFirst( - (s, flux) -> { - Payload payload = s.get(); - if (payload != null) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); - } - - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - return Mono.error(t); - } - return handleChannel(payload, flux); - } else { - return flux; - } - }, - false) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + @Override + public double availability() { + return Math.min(connection.availability(), leaseHandler.availability()); } - private Flux handleChannel(Payload initialPayload, Flux inboundFlux) { - final UnboundedProcessor sendProcessor = this.sendProcessor; - - final UnicastProcessor receiver = UnicastProcessor.create(); - - return receiver - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - final BaseSubscriber upstreamSubscriber = - new BaseSubscriber() { - - boolean first = true; - - @Override - protected void hookOnSubscribe(Subscription subscription) { - // noops - } - - @Override - protected void hookOnNext(Payload payload) { - if (first) { - // need to skip first since we have already sent it - // no need to release it since it was released earlier on the - // request - // establishment - // phase - first = false; - request(1); - return; - } - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - cancel(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - // no need to send any errors. - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - receiver.onError(t); - return; - } - final ByteBuf frame = - PayloadFrameCodec.encodeNextReleasingPayload( - allocator, streamId, payload); - - sendProcessor.onNext(frame); - } - - @Override - protected void hookOnComplete() { - ByteBuf frame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sendProcessor.onNext(frame); - } - - @Override - protected void hookOnError(Throwable t) { - ByteBuf frame = ErrorFrameCodec.encode(allocator, streamId, t); - sendProcessor.onNext(frame); - receiver.onError(t); - } - - @Override - protected void hookFinally(SignalType type) { - senders.remove(streamId, this); - } - }; - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - initialPayload.release(); - final Throwable t = terminationError; - upstreamSubscriber.cancel(); - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - initialPayload.release(); - receiver.onError(lh.leaseError()); - return; - } - - final int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - final ByteBuf frame = - RequestChannelFrameCodec.encodeReleasingPayload( - allocator, streamId, false, n, initialPayload); - - senders.put(streamId, upstreamSubscriber); - receivers.put(streamId, receiver); - - inboundFlux - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) - .subscribe(upstreamSubscriber); - - sendProcessor.onNext(frame); - } - - @Override - void hookOnRemainingRequests(long n) { - if (receiver.isDisposed()) { - return; - } - - sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); - } - - @Override - void hookOnCancel() { - senders.remove(streamId, upstreamSubscriber); - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } - } - - @Override - void hookOnTerminal(SignalType signalType) { - if (signalType == SignalType.ON_ERROR) { - upstreamSubscriber.cancel(); - } - receivers.remove(streamId, receiver); - } - - @Override - public void cancel() { - upstreamSubscriber.cancel(); - super.cancel(); - } - })) - .subscribeOn(serialScheduler, false); + @Override + public void dispose() { + tryShutdown(); } - private Mono handleMetadataPush(Payload payload) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); - } - - if (isDisposed()) { - Throwable err = this.terminationError; - payload.release(); - return Mono.error(err); - } - - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); - } - - final AtomicBoolean once = new AtomicBoolean(); - - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("MetadataPushMono allows only a single subscriber")); - } - - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Mono.error(t); - } - - ByteBuf metadataPushFrame = - MetadataPushFrameCodec.encodeReleasingPayload(allocator, payload); - - sendProcessor.onNextPrioritized(metadataPushFrame); + @Override + public boolean isDisposed() { + return terminationError != null; + } - return Mono.empty(); - }); + @Override + public Mono onClose() { + return onClose; } private void handleIncomingFrames(ByteBuf frame) { @@ -668,79 +234,42 @@ private void handleStreamZero(FrameType type, ByteBuf frame) { } private void handleFrame(int streamId, FrameType type, ByteBuf frame) { - Subscriber receiver = receivers.get(streamId); + FrameHandler receiver = this.get(streamId); + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + switch (type) { - case NEXT: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - receiver.onNext(payloadDecoder.apply(frame)); - break; case NEXT_COMPLETE: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); + receiver.handleNext(frame, false, true); + break; + case NEXT: + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); break; case COMPLETE: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - receiver.onComplete(); - receivers.remove(streamId); + receiver.handleComplete(); break; case ERROR: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - - // FIXME: when https://github.com/reactor/reactor-core/issues/2176 is resolved - // This is workaround to handle specific Reactor related case when - // onError call may not return normally - try { - receiver.onError(Exceptions.from(streamId, frame)); - } catch (RuntimeException e) { - if (reactor.core.Exceptions.isBubbling(e) - || reactor.core.Exceptions.isErrorCallbackNotImplemented(e)) { - if (LOGGER.isDebugEnabled()) { - Throwable unwrapped = reactor.core.Exceptions.unwrap(e); - LOGGER.debug("Unhandled dropped exception", unwrapped); - } - } - } - - receivers.remove(streamId); + receiver.handleError(Exceptions.from(streamId, frame)); break; case CANCEL: - { - Subscription sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - break; - } + receiver.handleCancel(); + break; case REQUEST_N: - { - Subscription sender = senders.get(streamId); - if (sender != null) { - long n = RequestNFrameCodec.requestN(frame); - sender.request(n); - } - break; - } + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + break; default: throw new IllegalStateException( "Requester received unsupported frame on stream " + streamId + ": " + frame.toString()); } } + @SuppressWarnings("ConstantConditions") private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { - if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { + if (!super.streamIdSupplier.isBeforeOrCurrent(streamId)) { if (type == FrameType.ERROR) { // message for stream that has never existed, we have a problem with // the overall connection and must tear down @@ -763,7 +292,7 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBu // so ignore (cancellation is async so there is a race condition) } - private void tryTerminateOnKeepAlive(KeepAlive keepAlive) { + private void tryTerminateOnKeepAlive(KeepAliveSupport.KeepAlive keepAlive) { tryTerminate( () -> new ConnectionErrorException( @@ -782,7 +311,7 @@ private void tryTerminate(Supplier errorSupplier) { if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - serialScheduler.schedule(() -> terminate(e)); + terminate(e); } } } @@ -790,7 +319,7 @@ private void tryTerminate(Supplier errorSupplier) { private void tryShutdown() { if (terminationError == null) { if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { - serialScheduler.schedule(() -> terminate(CLOSED_CHANNEL_EXCEPTION)); + terminate(CLOSED_CHANNEL_EXCEPTION); } } } @@ -802,33 +331,19 @@ private void terminate(Throwable e) { connection.dispose(); leaseHandler.dispose(); - receivers - .values() - .forEach( - receiver -> { - try { - receiver.onError(e); - } catch (Throwable t) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } - } - }); - senders - .values() - .forEach( - sender -> { - try { - sender.cancel(); - } catch (Throwable t) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); + synchronized (this) { + activeStreams + .values() + .forEach( + receiver -> { + try { + receiver.handleError(e); + } catch (Throwable ignored) { } - } - }); - senders.clear(); - receivers.clear(); - sendProcessor.dispose(); + }); + } + + this.getSendProcessor().dispose(); if (e == CLOSED_CHANNEL_EXCEPTION) { onClose.onComplete(); } else { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 581605ff4..8daea299a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -16,53 +16,38 @@ package io.rsocket.core; -import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; - import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; -import io.netty.util.ReferenceCounted; -import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.frame.*; +import io.rsocket.ResponderRSocket; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.lease.ResponderLeaseHandler; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Consumer; -import java.util.function.LongConsumer; import java.util.function.Supplier; -import org.reactivestreams.Processor; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.Exceptions; -import reactor.core.publisher.*; -import reactor.util.annotation.Nullable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ -class RSocketResponder implements RSocket { +class RSocketResponder extends RequesterResponderSupport implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketResponder.class); - private static final Consumer DROPPED_ELEMENTS_CONSUMER = - referenceCounted -> { - if (referenceCounted.refCnt() > 0) { - try { - referenceCounted.release(); - } catch (IllegalReferenceCountException e) { - // ignored - } - } - }; private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); private final DuplexConnection connection; @@ -71,7 +56,6 @@ class RSocketResponder implements RSocket { @SuppressWarnings("deprecation") private final io.rsocket.ResponderRSocket responderRSocket; - private final PayloadDecoder payloadDecoder; private final ResponderLeaseHandler leaseHandler; private final Disposable leaseHandlerDisposable; @@ -80,26 +64,16 @@ class RSocketResponder implements RSocket { AtomicReferenceFieldUpdater.newUpdater( RSocketResponder.class, Throwable.class, "terminationError"); - private final int mtu; - private final int maxFrameLength; - - private final IntObjectMap sendingSubscriptions; - private final IntObjectMap> channelProcessors; - - private final UnboundedProcessor sendProcessor; - private final ByteBufAllocator allocator; - RSocketResponder( DuplexConnection connection, RSocket requestHandler, PayloadDecoder payloadDecoder, ResponderLeaseHandler leaseHandler, int mtu, - int maxFrameLength) { + int maxFrameLength, + int maxInboundPayloadSize) { + super(mtu, maxFrameLength, maxInboundPayloadSize, payloadDecoder, connection.alloc(), null); this.connection = connection; - this.allocator = connection.alloc(); - this.mtu = mtu; - this.maxFrameLength = maxFrameLength; this.requestHandler = requestHandler; this.responderRSocket = @@ -107,14 +81,11 @@ class RSocketResponder implements RSocket { ? (io.rsocket.ResponderRSocket) requestHandler : null; - this.payloadDecoder = payloadDecoder; this.leaseHandler = leaseHandler; - this.sendingSubscriptions = new SynchronizedIntObjectHashMap<>(); - this.channelProcessors = new SynchronizedIntObjectHashMap<>(); // DO NOT Change the order here. The Send processor must be subscribed to before receiving // connections - this.sendProcessor = new UnboundedProcessor<>(); + UnboundedProcessor sendProcessor = super.getSendProcessor(); connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); @@ -127,31 +98,9 @@ class RSocketResponder implements RSocket { } private void handleSendProcessorError(Throwable t) { - sendingSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } - } - }); - - channelProcessors - .values() - .forEach( - subscription -> { - try { - subscription.onError(t); - } catch (Throwable e) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } - } - }); + for (FrameHandler frameHandler : activeStreams.values()) { + frameHandler.handleError(t); + } } private void tryTerminateOnConnectionError(Throwable e) { @@ -229,7 +178,12 @@ public Flux requestChannel(Publisher payloads) { private Flux requestChannel(Payload payload, Publisher payloads) { try { if (leaseHandler.useLease()) { - return responderRSocket.requestChannel(payload, payloads); + final ResponderRSocket responderRSocket = this.responderRSocket; + if (responderRSocket != null) { + return responderRSocket.requestChannel(payload, payloads); + } else { + return requestHandler.requestChannel(payloads); + } } else { payload.release(); return Flux.error(leaseHandler.leaseError()); @@ -265,113 +219,99 @@ public Mono onClose() { private void cleanup(Throwable e) { cleanUpSendingSubscriptions(); - cleanUpChannelProcessors(e); connection.dispose(); leaseHandlerDisposable.dispose(); requestHandler.dispose(); - sendProcessor.dispose(); + super.getSendProcessor().dispose(); } private synchronized void cleanUpSendingSubscriptions() { - sendingSubscriptions.values().forEach(Subscription::cancel); - sendingSubscriptions.clear(); - } - - private synchronized void cleanUpChannelProcessors(Throwable e) { - channelProcessors - .values() - .forEach( - payloadPayloadProcessor -> { - try { - payloadPayloadProcessor.onError(e); - } catch (Throwable t) { - // noops - } - }); - channelProcessors.clear(); + activeStreams.values().forEach(FrameHandler::handleCancel); + activeStreams.clear(); } private void handleFrame(ByteBuf frame) { try { int streamId = FrameHeaderCodec.streamId(frame); - Subscriber receiver; + FrameHandler receiver; FrameType frameType = FrameHeaderCodec.frameType(frame); switch (frameType) { case REQUEST_FNF: - handleFireAndForget(streamId, fireAndForget(payloadDecoder.apply(frame))); + handleFireAndForget(streamId, frame); break; case REQUEST_RESPONSE: - handleRequestResponse(streamId, requestResponse(payloadDecoder.apply(frame))); - break; - case CANCEL: - handleCancelFrame(streamId); - break; - case REQUEST_N: - handleRequestN(streamId, frame); + handleRequestResponse(streamId, frame); break; case REQUEST_STREAM: long streamInitialRequestN = RequestStreamFrameCodec.initialRequestN(frame); - Payload streamPayload = payloadDecoder.apply(frame); - handleStream(streamId, requestStream(streamPayload), streamInitialRequestN, null); + handleStream(streamId, frame, streamInitialRequestN); break; case REQUEST_CHANNEL: long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame); - Payload channelPayload = payloadDecoder.apply(frame); - handleChannel(streamId, channelPayload, channelInitialRequestN); + handleChannel(streamId, frame, channelInitialRequestN); break; case METADATA_PUSH: - handleMetadataPush(metadataPush(payloadDecoder.apply(frame))); + handleMetadataPush(metadataPush(super.getPayloadDecoder().apply(frame))); + break; + case CANCEL: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleCancel(); + } + break; + case REQUEST_N: + receiver = super.get(streamId); + if (receiver != null) { + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + } break; case PAYLOAD: // TODO: Hook in receiving socket. break; case NEXT: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - receiver.onNext(payloadDecoder.apply(frame)); + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); } break; case COMPLETE: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - receiver.onComplete(); + receiver.handleComplete(); } break; case ERROR: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - // FIXME: when https://github.com/reactor/reactor-core/issues/2176 is resolved - // This is workaround to handle specific Reactor related case when - // onError call may not return normally - try { - receiver.onError(io.rsocket.exceptions.Exceptions.from(streamId, frame)); - } catch (RuntimeException e) { - if (reactor.core.Exceptions.isBubbling(e) - || reactor.core.Exceptions.isErrorCallbackNotImplemented(e)) { - if (LOGGER.isDebugEnabled()) { - Throwable unwrapped = reactor.core.Exceptions.unwrap(e); - LOGGER.debug("Unhandled dropped exception", unwrapped); - } - } - } + receiver.handleError(io.rsocket.exceptions.Exceptions.from(streamId, frame)); } break; case NEXT_COMPLETE: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); + receiver.handleNext(frame, false, true); } break; case SETUP: - handleError(streamId, new IllegalStateException("Setup frame received post setup.")); + super.getSendProcessor() + .onNext( + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException("Setup frame received post setup."))); break; case LEASE: default: - handleError( - streamId, - new IllegalStateException("ServerRSocket: Unexpected frame type: " + frameType)); + super.getSendProcessor() + .onNext( + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException( + "ServerRSocket: Unexpected frame type: " + frameType))); break; } ReferenceCountUtil.safeRelease(frame); @@ -381,252 +321,82 @@ private void handleFrame(ByteBuf frame) { } } - private void handleFireAndForget(int streamId, Mono result) { - result.subscribe( - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - sendingSubscriptions.put(streamId, subscription); - subscription.request(Long.MAX_VALUE); - } - - @Override - protected void hookOnError(Throwable throwable) {} + private void handleFireAndForget(int streamId, ByteBuf frame) { + if (FrameHeaderCodec.hasFollows(frame)) { + FireAndForgetResponderSubscriber subscriber = + new FireAndForgetResponderSubscriber(streamId, frame, this, this); - @Override - protected void hookFinally(SignalType type) { - sendingSubscriptions.remove(streamId); - } - }); + this.add(streamId, subscriber); + } else { + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + } } - private void handleRequestResponse(int streamId, Mono response) { - final BaseSubscriber subscriber = - new BaseSubscriber() { - private boolean isEmpty = true; - - @Override - protected void hookOnNext(Payload payload) { - if (isEmpty) { - isEmpty = false; - } - - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - cancel(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - handleError(streamId, t); - return; - } - - ByteBuf byteBuf = - PayloadFrameCodec.encodeNextCompleteReleasingPayload(allocator, streamId, payload); - sendProcessor.onNext(byteBuf); - } - - @Override - protected void hookOnError(Throwable throwable) { - if (sendingSubscriptions.remove(streamId, this)) { - handleError(streamId, throwable); - } - } + private void handleRequestResponse(int streamId, ByteBuf frame) { + if (FrameHeaderCodec.hasFollows(frame)) { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, frame, this, this); - @Override - protected void hookOnComplete() { - if (isEmpty) { - if (sendingSubscriptions.remove(streamId, this)) { - sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); - } - } - } - }; + this.add(streamId, subscriber); + } else { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, this); - sendingSubscriptions.put(streamId, subscriber); - response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); + if (this.add(streamId, subscriber)) { + this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } } - private void handleStream( - int streamId, - Flux response, - long initialRequestN, - @Nullable UnicastProcessor requestChannel) { - final BaseSubscriber subscriber = - new BaseSubscriber() { - - @Override - protected void hookOnSubscribe(Subscription s) { - s.request(initialRequestN); - } - - @Override - protected void hookOnNext(Payload payload) { - try { - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - - cancelStream(t); - return; - } - - ByteBuf byteBuf = - PayloadFrameCodec.encodeNextReleasingPayload(allocator, streamId, payload); - sendProcessor.onNext(byteBuf); - } catch (Throwable e) { - cancelStream(e); - } - } - - private void cancelStream(Throwable t) { - // Cancel the output stream and send an ERROR frame but do not dispose the - // requestChannel (i.e. close the connection) since the spec allows to leave - // the channel in half-closed state. - // specifically for requestChannel case so when Payload is invalid we will not be - // sending CancelFrame and ErrorFrame - // Note: CancelFrame is redundant and due to spec - // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) - // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream - // is terminated on both Requester and Responder. - // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is - // terminated on both the Requester and Responder. - if (requestChannel != null) { - channelProcessors.remove(streamId, requestChannel); - } - cancel(); - handleError(streamId, t); - } - - @Override - protected void hookOnComplete() { - if (sendingSubscriptions.remove(streamId, this)) { - sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); - } - } - - @Override - protected void hookOnError(Throwable throwable) { - if (sendingSubscriptions.remove(streamId, this)) { - // specifically for requestChannel case so when Payload is invalid we will not be - // sending CancelFrame and ErrorFrame - // Note: CancelFrame is redundant and due to spec - // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) - // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream - // is terminated on both Requester and Responder. - // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is - // terminated on both the Requester and Responder. - if (requestChannel != null && !requestChannel.isDisposed()) { - if (channelProcessors.remove(streamId, requestChannel)) { - try { - requestChannel.dispose(); - } catch (Throwable e) { - // ignore to ensure it does not blows up if it racing with async - // cancel - } - } - } - - handleError(streamId, throwable); - } - } - }; + private void handleStream(int streamId, ByteBuf frame, long initialRequestN) { + if (FrameHeaderCodec.hasFollows(frame)) { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); - sendingSubscriptions.put(streamId, subscriber); - response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); - } - - private void handleChannel(int streamId, Payload payload, long initialRequestN) { - UnicastProcessor frames = UnicastProcessor.create(); - channelProcessors.put(streamId, frames); - - Flux payloads = - frames - .doOnRequest( - new LongConsumer() { - boolean first = true; - - @Override - public void accept(long l) { - long n; - if (first) { - first = false; - n = l - 1L; - } else { - n = l; - } - if (n > 0) { - sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); - } - } - }) - .doFinally( - signalType -> { - if (channelProcessors.remove(streamId, frames)) { - if (signalType == SignalType.CANCEL) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else if (signalType == SignalType.ON_ERROR) { - Subscription subscription = sendingSubscriptions.remove(streamId); - if (subscription != null) { - subscription.cancel(); - } - } - } - }) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - - // not chained, as the payload should be enqueued in the Unicast processor before this method - // returns - // and any later payload can be processed - frames.onNext(payload); - - if (responderRSocket != null) { - handleStream(streamId, requestChannel(payload, payloads), initialRequestN, frames); + this.add(streamId, subscriber); } else { - handleStream(streamId, requestChannel(payloads), initialRequestN, frames); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + + if (this.add(streamId, subscriber)) { + this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } } } - private void handleMetadataPush(Mono result) { - result.subscribe( - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - subscription.request(Long.MAX_VALUE); - } - - @Override - protected void hookOnError(Throwable throwable) {} - }); - } + private void handleChannel(int streamId, ByteBuf frame, long initialRequestN) { + if (FrameHeaderCodec.hasFollows(frame)) { + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); - private void handleCancelFrame(int streamId) { - Subscription subscription = sendingSubscriptions.remove(streamId); - Processor processor = channelProcessors.remove(streamId); + this.add(streamId, subscriber); + } else { + final Payload firstPayload = super.getPayloadDecoder().apply(frame); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); - if (processor != null) { - try { - processor.onError(new CancellationException("Disposed")); - } catch (Exception e) { - // ignore + if (this.add(streamId, subscriber)) { + this.requestChannel(firstPayload, subscriber).subscribe(subscriber); } } - - if (subscription != null) { - subscription.cancel(); - } } - private void handleError(int streamId, Throwable t) { - sendProcessor.onNext(ErrorFrameCodec.encode(allocator, streamId, t)); + private void handleMetadataPush(Mono result) { + result.subscribe(MetadataPushResponderSubscriber.INSTANCE); } - private void handleRequestN(int streamId, ByteBuf frame) { - Subscription subscription = sendingSubscriptions.get(streamId); + private boolean add(int streamId, FrameHandler frameHandler) { + FrameHandler existingHandler; + synchronized (this) { + existingHandler = super.activeStreams.putIfAbsent(streamId, frameHandler); + } - if (subscription != null) { - long n = RequestNFrameCodec.requestN(frame); - subscription.request(n); + if (existingHandler != null) { + frameHandler.handleCancel(); + return false; } + + return true; } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index 610636f02..3d72d9a0c 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -16,7 +16,9 @@ package io.rsocket.core; +import static io.rsocket.core.FragmentationUtils.assertMtu; import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import io.netty.buffer.ByteBuf; @@ -28,8 +30,6 @@ import io.rsocket.SocketAcceptor; import io.rsocket.exceptions.InvalidSetupException; import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.fragmentation.FragmentationDuplexConnection; -import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; @@ -45,7 +45,6 @@ import java.util.function.Consumer; import java.util.function.Supplier; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; /** * The main class for starting an RSocket server. @@ -218,8 +217,7 @@ public RSocketServer lease(Supplier> supplier) { * and Reassembly */ public RSocketServer maxInboundPayloadSize(int maxInboundPayloadSize) { - this.maxInboundPayloadSize = - ReassemblyDuplexConnection.assertInboundPayloadSize(maxInboundPayloadSize); + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); return this; } @@ -237,7 +235,7 @@ public RSocketServer maxInboundPayloadSize(int maxInboundPayloadSize) { * and Reassembly */ public RSocketServer fragment(int mtu) { - this.mtu = FragmentationDuplexConnection.assertMtu(mtu); + this.mtu = assertMtu(mtu); return this; } @@ -337,10 +335,6 @@ public Mono apply(DuplexConnection connection) { private Mono acceptor( ServerSetup serverSetup, DuplexConnection connection, int maxFrameLength) { - connection = - mtu > 0 - ? new FragmentationDuplexConnection(connection, mtu, maxInboundPayloadSize, "server") - : new ReassemblyDuplexConnection(connection, maxInboundPayloadSize); ClientServerInputMultiplexer multiplexer = new ClientServerInputMultiplexer(connection, interceptors, false); @@ -430,11 +424,11 @@ private Mono acceptSetup( StreamIdSupplier.serverSupplier(), mtu, maxFrameLength, + maxInboundPayloadSize, setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, - requesterLeaseHandler, - Schedulers.single(Schedulers.parallel())); + requesterLeaseHandler); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); @@ -464,7 +458,8 @@ private Mono acceptSetup( payloadDecoder, responderLeaseHandler, mtu, - maxFrameLength); + maxFrameLength, + maxInboundPayloadSize); }) .doFinally(signalType -> setupPayload.release()) .then(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java new file mode 100644 index 000000000..7c43991db --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java @@ -0,0 +1,237 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.MIN_MTU_SIZE; +import static io.rsocket.core.StateUtils.isReassembling; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.markReassembled; +import static io.rsocket.core.StateUtils.markReassembling; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +class ReassemblyUtils { + static final String ILLEGAL_REASSEMBLED_PAYLOAD_SIZE = + "Reassembled payload size went out of allowed %s bytes"; + + @SuppressWarnings("ConstantConditions") + static void release(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + frames.release(); + } + } + + @SuppressWarnings({"ConstantConditions", "SynchronizationOnLocalVariableOrMethodParameter"}) + static void synchronizedRelease(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + + synchronized (frames) { + frames.release(); + } + } + } + + static void handleNextSupport( + AtomicLongFieldUpdater updater, + T instance, + Subscription subscription, + CoreSubscriber inboundSubscriber, + PayloadDecoder payloadDecoder, + ByteBufAllocator allocator, + int maxInboundPayloadSize, + ByteBuf frame, + boolean hasFollows, + boolean isLastPayload) { + + long state = updater.get(instance); + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = payloadDecoder.apply(frame); + } catch (Throwable t) { + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + if (isLastPayload) { + instance.handleComplete(); + } + return; + } + + CompositeByteBuf frames = instance.getFrames(); + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), frame, maxInboundPayloadSize); + instance.setFrames(frames); + + long previousState = markReassembling(updater, instance); + if (isTerminated(previousState)) { + instance.setFrames(null); + frames.release(); + return; + } + } else { + try { + frames = ReassemblyUtils.addFollowingFrame(frames, frame, maxInboundPayloadSize); + } catch (IllegalStateException t) { + if (isTerminated(updater.get(instance))) { + return; + } + + instance.setFrames(null); + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(updater, instance); + if (isTerminated(previousState)) { + return; + } + + instance.setFrames(null); + + Payload payload; + try { + payload = payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + + if (isLastPayload) { + instance.handleComplete(); + } + } + } + + static CompositeByteBuf addFollowingFrame( + CompositeByteBuf frames, ByteBuf followingFrame, int maxInboundPayloadSize) { + int readableBytes = frames.readableBytes(); + if (readableBytes == 0) { + return frames.addComponent(true, followingFrame.retain()); + } else if (maxInboundPayloadSize != Integer.MAX_VALUE + && readableBytes + followingFrame.readableBytes() - FrameHeaderCodec.size() + > maxInboundPayloadSize) { + frames.release(); + throw new IllegalStateException( + String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)); + } + + final boolean hasMetadata = FrameHeaderCodec.hasMetadata(followingFrame); + + // skip headers + followingFrame.skipBytes(FrameHeaderCodec.size()); + + // if has metadata, then we have to increase metadata length in containing frames + // CompositeByteBuf + if (hasMetadata) { + frames.markReaderIndex().skipBytes(FrameHeaderCodec.size()); + + final int nextMetadataLength = decodeLength(frames) + decodeLength(followingFrame); + + frames.resetReaderIndex(); + + frames.markWriterIndex(); + frames.writerIndex(FrameHeaderCodec.size()); + + encodeLength(frames, nextMetadataLength); + + frames.resetWriterIndex(); + } + + synchronized (frames) { + if (frames.refCnt() > 0) { + followingFrame.retain(); + return frames.addComponent(true, followingFrame); + } else { + throw new IllegalReferenceCountException(0); + } + } + } + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + int length = (byteBuf.readByte() & 0xFF) << 16; + length |= (byteBuf.readByte() & 0xFF) << 8; + length |= byteBuf.readByte() & 0xFF; + return length; + } + + static int assertInboundPayloadSize(int inboundPayloadSize) { + if (inboundPayloadSize < MIN_MTU_SIZE) { + String msg = + String.format( + "The min allowed inboundPayloadSize size is %d bytes, provided: %d", + FrameLengthCodec.FRAME_LENGTH_MASK, inboundPayloadSize); + throw new IllegalArgumentException(msg); + } else { + return inboundPayloadSize; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java new file mode 100644 index 000000000..fd2193948 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -0,0 +1,549 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.DISCARD_CONTEXT; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.UnboundedProcessor; +import java.util.Objects; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestChannelRequesterFlux extends Flux + implements RequesterFrameHandler, CoreSubscriber, Subscription, Scannable { + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final UnboundedProcessor sendProcessor; + final PayloadDecoder payloadDecoder; + + final Publisher payloadsPublisher; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelRequesterFlux.class, "state"); + + int streamId; + + Context cachedContext; + + boolean isFirstPayload = true; + + CoreSubscriber inboundSubscriber; + Subscription outboundSubscription; + boolean inboundDone; + boolean outboundDone; + + CompositeByteBuf frames; + + RequestChannelRequesterFlux( + Publisher payloadsPublisher, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadsPublisher = payloadsPublisher; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("RequestChannelFlux allows only a single Subscriber")); + return; + } + + this.inboundSubscriber = actual; + this.payloadsPublisher.subscribe(this); + } + + @Override + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + this.inboundSubscriber.onSubscribe(this); + } + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + long previousState = addRequestN(STATE, this, n); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, this.streamId, n); + this.sendProcessor.onNext(requestNFrame); + } + return; + } + + // do first request + this.outboundSubscription.request(1); + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + p.release(); + return; + } + + if (this.isFirstPayload) { + this.isFirstPayload = false; + + long state = this.state; + if (isTerminated(state)) { + p.release(); + return; + } + sendFirstPayload(p, extractRequestN(state)); + } else { + sendFollowingPayload(p); + } + } + + void sendFirstPayload(Payload firstPayload, long initialRequestN) { + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, firstPayload, true)) { + lazyTerminate(STATE, this); + + firstPayload.release(); + this.outboundSubscription.cancel(); + + this.inboundDone = true; + this.inboundSubscriber.onError( + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + this.outboundSubscription.cancel(); + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final UnboundedProcessor sender = this.sendProcessor; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + lazyTerminate(STATE, this); + + firstPayload.release(); + this.outboundSubscription.cancel(); + + this.inboundDone = true; + this.inboundSubscriber.onError(Exceptions.unwrap(t)); + return; + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_CHANNEL, + initialRequestN, + mtu, + firstPayload, + sender, + allocator, + // TODO: Should be a different flag in case of the scalar + // source or if we know in advance upstream is mono + false); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + this.outboundSubscription.cancel(); + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.inboundDone) { + return; + } + + sm.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + sender.onNext(cancelFrame); + + return; + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + sender.onNext(requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + sender.onNext(requestNFrame); + } + } + + final void sendFollowingPayload(Payload followingPayload) { + int streamId = this.streamId; + int mtu = this.mtu; + + try { + if (!isValid(mtu, this.maxFrameLength, followingPayload, true)) { + followingPayload.release(); + + this.cancel(); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + this.propagateErrorSafely(e); + return; + } + } catch (IllegalReferenceCountException e) { + this.cancel(); + + this.propagateErrorSafely(e); + + return; + } + + try { + sendReleasingPayload( + streamId, + + // TODO: Should be a different flag in case of the scalar + // source or if we know in advance upstream is mono + FrameType.NEXT, + mtu, + followingPayload, + this.sendProcessor, + allocator, + true); + } catch (Throwable e) { + this.cancel(); + + this.propagateErrorSafely(e); + } + } + + void propagateErrorSafely(Throwable e) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + if (!this.inboundDone) { + synchronized (this) { + if (!this.inboundDone) { + this.inboundDone = true; + this.inboundSubscriber.onError(e); + } else { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + } + } + } else { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + } + } + + @Override + public final void cancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + this.outboundSubscription.cancel(); + + if (!isFirstFrameSent(previousState)) { + // no need to send anything, since we have not started a stream yet (no logical wire) + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.sendProcessor.onNext(cancelFrame); + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (!isFirstFrameSent(previousState)) { + // first signal, thus, just propagates error to actual subscriber + this.inboundSubscriber.onError(t); + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + // propagates error to remote responder + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.sendProcessor.onNext(errorFrame); + + if (!isInboundTerminated(previousState)) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, true); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + if (!isFirstFrameSent(previousState)) { + // first signal, thus, just propagates error to actual subscriber + this.inboundSubscriber.onError(new CancellationException("Empty Source")); + return; + } + + final int streamId = this.streamId; + + if (isInboundTerminated(previousState)) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.sendProcessor.onNext(completeFrame); + } + + @Override + public final void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isOutboundTerminated(previousState)) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handleError(Throwable cause) { + if (this.inboundDone) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + ReassemblyUtils.release(this, previousState); + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + this.outboundSubscription.cancel(); + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handlePayload(Payload value) { + synchronized (this) { + if (this.inboundDone) { + value.release(); + return; + } + + this.inboundSubscriber.onNext(value); + } + } + + @Override + public void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public void handleCancel() { + if (this.outboundDone) { + return; + } + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + if (isInboundTerminated(previousState)) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + this.outboundSubscription.cancel(); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + @NonNull + public Context currentContext() { + long state = this.state; + + if (isSubscribedOrTerminated(state)) { + Context contextWithDiscard = this.inboundSubscriber.currentContext().putAll(DISCARD_CONTEXT); + cachedContext = contextWithDiscard; + return contextWithDiscard; + } + + return Context.empty(); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return state; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestChannelFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java new file mode 100644 index 000000000..1f022322c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -0,0 +1,726 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; +import static reactor.core.Exceptions.TERMINATED; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.UnboundedProcessor; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +final class RequestChannelResponderSubscriber extends Flux + implements ResponderFrameHandler, Subscription, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestChannelResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final UnboundedProcessor sendProcessor; + final long firstRequest; + + final RSocket handler; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelResponderSubscriber.class, "state"); + + Payload firstPayload; + + Subscription outboundSubscription; + CoreSubscriber inboundSubscriber; + + CompositeByteBuf frames; + + volatile Throwable inboundError; + static final AtomicReferenceFieldUpdater + INBOUND_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RequestChannelResponderSubscriber.class, Throwable.class, "inboundError"); + + boolean inboundDone; + boolean outboundDone; + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.handler = handler; + this.firstRequest = firstRequestN; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, maxInboundPayloadSize); + STATE.lazySet(this, REASSEMBLING_FLAG); + } + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + Payload firstPayload, + RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.firstRequest = firstRequestN; + this.firstPayload = firstPayload; + + this.handler = null; + this.frames = null; + } + + @Override + // subscriber from the requestChannel method + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isTerminated(previousState)) { + Throwable t = Exceptions.terminate(INBOUND_ERROR, this); + if (t != TERMINATED) { + //noinspection ConstantConditions + Operators.error(actual, t); + } else { + Operators.error( + actual, + new CancellationException("RequestChannelSubscriber has already been terminated")); + } + return; + } + + if (isSubscribed(previousState)) { + Operators.error( + actual, new IllegalStateException("RequestChannelSubscriber allows only one Subscriber")); + return; + } + + this.inboundSubscriber = actual; + // sends sender as a subscription since every request|cancel signal should be encoded to + // requestNFrame|cancelFrame + actual.onSubscribe(this); + } + + @Override + // subscription to the outbound + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + outboundSubscription.request(this.firstRequest); + } + } + + @Override + public void request(long n) { + if (!Operators.validate(n)) { + return; + } + + long previousState = StateUtils.addRequestN(STATE, this, n); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null || this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + Payload firstPayload = this.firstPayload; + if (firstPayload != null) { + this.firstPayload = null; + inboundSubscriber.onNext(firstPayload); + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else { + inboundSubscriber.onComplete(); + } + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (!hasRequested(previousState) && !isFirstFrameSent(previousState) && this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + inboundSubscriber.onNext(firstPayload); + inboundSubscriber.onComplete(); + + markFirstFrameSent(STATE, this); + } + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(StateUtils.extractRequestN(previousState))) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, this.streamId, n); + this.sendProcessor.onNext(requestNFrame); + } + return; + } + + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + inboundSubscriber.onNext(firstPayload); + + previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + long requestN = StateUtils.extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + this.sendProcessor.onNext(requestNFrame); + } else { + long firstRequestN = requestN - 1; + if (firstRequestN > 0) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(this.allocator, this.streamId, firstRequestN); + this.sendProcessor.onNext(requestNFrame); + } + } + } + + @Override + // inbound cancellation + public void cancel() { + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + return; + } + + if (!isFirstFrameSent(previousState) && !hasRequested(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + + final int streamId = this.streamId; + + if (isOutboundTerminated(previousState)) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.sendProcessor.onNext(cancelFrame); + } + + @Override + public final void handleCancel() { + Subscription outboundSubscription = this.outboundSubscription; + if (outboundSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + lazyTerminate(STATE, this); + + this.requesterResponderSupport.remove(this.streamId, this); + + final CompositeByteBuf frames = this.frames; + this.frames = null; + frames.release(); + return; + } + + this.tryTerminate(true); + } + + final long tryTerminate(boolean isFromInbound) { + Exceptions.addThrowable( + INBOUND_ERROR, this, new CancellationException("Inbound has been canceled")); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return previousState; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + if (isFromInbound) { + frames.release(); + } else { + synchronized (frames) { + frames.release(); + } + } + } + + this.outboundSubscription.cancel(); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + if (isFromInbound) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } else { + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + } + + return previousState; + } + + final void handlePayload(Payload p) { + synchronized (this) { + if (this.inboundDone) { + // payload from network so it has refCnt > 0 + p.release(); + return; + } + + this.inboundSubscriber.onNext(p); + } + } + + @Override + public final void handleError(Throwable t) { + if (this.inboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, t); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + frames.release(); + } + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + } + + @Override + public void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + + if (isOutboundTerminated(previousState)) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + if (isFirstFrameSent(previousState)) { + this.inboundSubscriber.onComplete(); + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + long state = this.state; + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = this.payloadDecoder.apply(frame); + } catch (Throwable t) { + long previousState = this.tryTerminate(true); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + // send error to terminate interaction + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, this.streamId, new CanceledException(t.getMessage())); + this.sendProcessor.onNext(errorFrame); + + return; + } + + this.handlePayload(payload); + if (isLastPayload) { + this.handleComplete(); + } + return; + } + + CompositeByteBuf frames = this.frames; + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + this.allocator.compositeBuffer(), frame, this.maxInboundPayloadSize); + this.frames = frames; + + long previousState = markReassembling(STATE, this); + if (isTerminated(previousState)) { + this.frames = null; + frames.release(); + return; + } + } else { + try { + frames = ReassemblyUtils.addFollowingFrame(frames, frame, this.maxInboundPayloadSize); + } catch (IllegalReferenceCountException e) { + if (isTerminated(this.state)) { + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.frames = null; + + this.outboundDone = true; + // send error to terminate interaction + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + this.streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.sendProcessor.onNext(errorFrame); + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(STATE, this); + if (isTerminated(previousState)) { + return; + } + + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + previousState = this.tryTerminate(true); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + // send error to terminate interaction + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + this.streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.sendProcessor.onNext(errorFrame); + + return; + } + + if (this.outboundSubscription == null) { + this.firstPayload = payload; + Flux source = this.handler.requestChannel(this); + source.subscribe(this); + } else { + this.handlePayload(payload); + } + + if (isLastPayload) { + this.handleComplete(); + } + } + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final UnboundedProcessor sender = this.sendProcessor; + final ByteBufAllocator allocator = this.allocator; + + if (p == null) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); + sender.onNext(completeFrame); + return; + } + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped( + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), + this.inboundSubscriber.currentContext()); + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + sender.onNext(errorFrame); + return; + } + } catch (IllegalReferenceCountException e) { + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause:" + e.getMessage())); + sender.onNext(errorFrame); + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); + } catch (Throwable t) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + this.tryTerminate(false); + } + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + boolean wasThrowableAdded = + Exceptions.addThrowable( + INBOUND_ERROR, + this, + new CancellationException("Outbound has terminated with an error")); + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + synchronized (frames) { + frames.release(); + } + } + + if (!isFirstFrameSent(previousState)) { + if (!hasRequested(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + } + + if (wasThrowableAdded && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.sendProcessor.onNext(errorFrame); + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + if (isInboundTerminated(previousState)) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.sendProcessor.onNext(completeFrame); + } + + @Override + public final void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java deleted file mode 100644 index 6123b0492..000000000 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java +++ /dev/null @@ -1,189 +0,0 @@ -package io.rsocket.core; - -import io.rsocket.Payload; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Fuseable; -import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; -import reactor.util.context.Context; - -/** - * This is a support class for handling of request input, intended for use with {@link - * Operators#lift}. It ensures serial execution of cancellation vs first request signals and also - * provides hooks for separate handling of first vs subsequent {@link Subscription#request} - * invocations. - */ -abstract class RequestOperator - implements CoreSubscriber, Fuseable.QueueSubscription { - - final CoreSubscriber actual; - - Subscription s; - Fuseable.QueueSubscription qs; - - int streamId; - boolean firstRequest = true; - - volatile int wip; - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(RequestOperator.class, "wip"); - - RequestOperator(CoreSubscriber actual) { - this.actual = actual; - } - - /** - * Optional hook executed exactly once on the first {@link Subscription#request) invocation - * and right after the {@link Subscription#request} was propagated to the upstream subscription. - * - *

Note: this hook may not be invoked if cancellation happened before this invocation - */ - void hookOnFirstRequest(long n) {} - - /** - * Optional hook executed after the {@link Subscription#request} was propagated to the upstream - * subscription and excludes the first {@link Subscription#request} invocation. - */ - void hookOnRemainingRequests(long n) {} - - /** Optional hook executed after this {@link Subscription} cancelling. */ - void hookOnCancel() {} - - /** - * Optional hook executed after {@link org.reactivestreams.Subscriber} termination events - * (onError, onComplete). - * - * @param signalType the type of termination event that triggered the hook ({@link - * SignalType#ON_ERROR} or {@link SignalType#ON_COMPLETE}) - */ - void hookOnTerminal(SignalType signalType) {} - - @Override - public Context currentContext() { - return actual.currentContext(); - } - - @Override - public void request(long n) { - this.s.request(n); - if (!firstRequest) { - try { - this.hookOnRemainingRequests(n); - } catch (Throwable throwable) { - onError(throwable); - } - return; - } - - if (WIP.getAndIncrement(this) != 0) { - return; - } - - this.firstRequest = false; - int missed = 1; - - boolean firstLoop = true; - for (; ; ) { - if (firstLoop) { - firstLoop = false; - try { - this.hookOnFirstRequest(n); - } catch (Throwable throwable) { - onError(throwable); - return; - } - } else { - try { - this.hookOnCancel(); - } catch (Throwable throwable) { - onError(throwable); - } - return; - } - - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { - return; - } - } - } - - @Override - public void cancel() { - this.s.cancel(); - - if (WIP.getAndIncrement(this) != 0) { - return; - } - - hookOnCancel(); - } - - @Override - @SuppressWarnings("unchecked") - public void onSubscribe(Subscription s) { - if (Operators.validate(this.s, s)) { - this.s = s; - if (s instanceof Fuseable.QueueSubscription) { - this.qs = (Fuseable.QueueSubscription) s; - } - this.actual.onSubscribe(this); - } - } - - @Override - public void onNext(Payload t) { - this.actual.onNext(t); - } - - @Override - public void onError(Throwable t) { - this.actual.onError(t); - try { - this.hookOnTerminal(SignalType.ON_ERROR); - } catch (Throwable throwable) { - Operators.onErrorDropped(throwable, currentContext()); - } - } - - @Override - public void onComplete() { - this.actual.onComplete(); - try { - this.hookOnTerminal(SignalType.ON_COMPLETE); - } catch (Throwable throwable) { - Operators.onErrorDropped(throwable, currentContext()); - } - } - - @Override - public int requestFusion(int requestedMode) { - if (this.qs != null) { - return this.qs.requestFusion(requestedMode); - } else { - return Fuseable.NONE; - } - } - - @Override - public Payload poll() { - return this.qs.poll(); - } - - @Override - public int size() { - return this.qs.size(); - } - - @Override - public boolean isEmpty() { - return this.qs.isEmpty(); - } - - @Override - public void clear() { - this.qs.clear(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java new file mode 100644 index 000000000..ce1f40355 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -0,0 +1,290 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.UnboundedProcessor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestResponseRequesterMono extends Mono + implements RequesterFrameHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final UnboundedProcessor sendProcessor; + final PayloadDecoder payloadDecoder; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestResponseRequesterMono.class, "state"); + + int streamId; + CoreSubscriber actual; + CompositeByteBuf frames; + boolean done; + + RequestResponseRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("RequestResponseMono allows only a single Subscriber")); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + Operators.error( + actual, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + p.release(); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + long previousState = addRequestN(STATE, this, n); + if (isTerminated(previousState) || hasRequested(previousState)) { + return; + } + + sendFirstPayload(this.payload, n); + } + + void sendFirstPayload(Payload payload, long initialRequestN) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final UnboundedProcessor sender = this.sendProcessor; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + lazyTerminate(STATE, this); + payload.release(); + this.actual.onError(Exceptions.unwrap(t)); + return; + } + + try { + sendReleasingPayload( + streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, sender, allocator, true); + } catch (Throwable e) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + this.actual.onError(e); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + sm.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + sender.onNext(cancelFrame); + } + } + + @Override + public final void cancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.sendProcessor.onNext(CancelFrameCodec.encode(this.allocator, streamId)); + } else if (!hasRequested(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload value) { + if (this.done) { + value.release(); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + value.release(); + return; + } + + final CoreSubscriber a = this.actual; + + this.requesterResponderSupport.remove(this.streamId, this); + + a.onNext(value); + a.onComplete(); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + this.actual.onComplete(); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.requesterResponderSupport.remove(this.streamId, this); + + this.actual.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.actual, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.PREFETCH) return 0; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestResponseMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java new file mode 100644 index 000000000..3a5afdb96 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -0,0 +1,244 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.UnboundedProcessor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestResponseResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestResponseResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final UnboundedProcessor sendProcessor; + + final RSocket handler; + + CompositeByteBuf frames; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestResponseResponderSubscriber.class, Subscription.class, "s"); + + public RequestResponseResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.handler = handler; + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, maxInboundPayloadSize); + } + + public RequestResponseResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + S.lazySet(this, subscription); + subscription.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(@Nullable Payload p) { + if (!Operators.terminate(S, this)) { + if (p != null) { + p.release(); + } + return; + } + + final int streamId = this.streamId; + final UnboundedProcessor sender = this.sendProcessor; + final ByteBufAllocator allocator = this.allocator; + + this.requesterResponderSupport.remove(streamId, this); + + if (p == null) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); + sender.onNext(completeFrame); + return; + } + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + sender.onNext(errorFrame); + return; + } + } catch (IllegalReferenceCountException e) { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + sender.onNext(errorFrame); + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, sender, allocator, false); + } catch (Throwable ignored) { + } + } + + @Override + public void onError(Throwable t) { + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + logger.debug("Dropped error", t); + return; + } + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.sendProcessor.onNext(errorFrame); + } + + @Override + public void onComplete() { + onNext(null); + } + + @Override + public void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + this.requesterResponderSupport.remove(this.streamId, this); + + final CompositeByteBuf frames = this.frames; + this.frames = null; + frames.release(); + + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + currentSubscription.cancel(); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = + ReassemblyUtils.addFollowingFrame(this.frames, frame, this.maxInboundPayloadSize); + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + logger.debug("Reassembly has failed", t); + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + this.streamId, + new CanceledException("Failed to reassemble payload. Cause" + t.getMessage())); + this.sendProcessor.onNext(errorFrame); + return; + } + + final Mono source = this.handler.requestResponse(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java new file mode 100644 index 000000000..8f6b1501f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -0,0 +1,313 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.UnboundedProcessor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestStreamRequesterFlux extends Flux + implements RequesterFrameHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final UnboundedProcessor sendProcessor; + final PayloadDecoder payloadDecoder; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestStreamRequesterFlux.class, "state"); + + int streamId; + CoreSubscriber inboundSubscriber; + CompositeByteBuf frames; + boolean done; + + RequestStreamRequesterFlux(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("RequestStreamFlux allows only a single Subscriber")); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + Operators.error( + actual, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + p.release(); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + this.inboundSubscriber = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + long previousState = addRequestN(STATE, this, n); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, this.streamId, n); + this.sendProcessor.onNext(requestNFrame); + } + return; + } + + sendFirstPayload(this.payload, n); + } + + void sendFirstPayload(Payload payload, long initialRequestN) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final UnboundedProcessor sender = this.sendProcessor; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + lazyTerminate(STATE, this); + + payload.release(); + + this.inboundSubscriber.onError(Exceptions.unwrap(t)); + return; + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_STREAM, + initialRequestN, + this.mtu, + payload, + sender, + allocator, + false); + } catch (Throwable e) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + this.inboundSubscriber.onError(e); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + sm.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + sender.onNext(cancelFrame); + + return; + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + sender.onNext(requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + sender.onNext(requestNFrame); + } + } + + @Override + public final void cancel() { + final long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.sendProcessor.onNext(CancelFrameCodec.encode(this.allocator, streamId)); + } else if (!hasRequested(previousState)) { + // no need to send anything, since the first request has not happened + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload p) { + if (this.done) { + p.release(); + return; + } + + this.inboundSubscriber.onNext(p); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.inboundSubscriber.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return extractRequestN(state); + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestStreamFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java new file mode 100644 index 000000000..5d85ba2fa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -0,0 +1,287 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.UnboundedProcessor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +final class RequestStreamResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestStreamResponderSubscriber.class); + + final int streamId; + final long firstRequest; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final UnboundedProcessor sendProcessor; + + final RSocket handler; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestStreamResponderSubscriber.class, Subscription.class, "s"); + + CompositeByteBuf frames; + boolean done; + + public RequestStreamResponderSubscriber( + int streamId, + long firstRequest, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.handler = handler; + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, maxInboundPayloadSize); + } + + public RequestStreamResponderSubscriber( + int streamId, long firstRequest, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.sendProcessor = requesterResponderSupport.getSendProcessor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + final long firstRequest = this.firstRequest; + S.lazySet(this, subscription); + subscription.request(firstRequest); + } + } + + @Override + public void onNext(Payload p) { + if (this.done) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final UnboundedProcessor sender = this.sendProcessor; + final ByteBufAllocator allocator = this.allocator; + + if (p == null) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); + sender.onNext(completeFrame); + return; + } + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + this.handleCancel(); + + this.done = true; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + sender.onNext(errorFrame); + return; + } + } catch (IllegalReferenceCountException e) { + this.handleCancel(); + this.done = true; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + sender.onNext(errorFrame); + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); + } catch (Throwable t) { + this.handleCancel(); + this.done = true; + } + } + + @Override + public void onError(Throwable t) { + if (this.done) { + logger.debug("Dropped error", t); + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + logger.debug("Dropped error", t); + return; + } + + final CompositeByteBuf frames = this.frames; + if (frames != null && frames.refCnt() > 0) { + frames.release(); + } + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.sendProcessor.onNext(errorFrame); + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + return; + } + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.sendProcessor.onNext(completeFrame); + } + + @Override + public void handleRequestN(long n) { + this.s.request(n); + } + + @Override + public final void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + this.requesterResponderSupport.remove(this.streamId, this); + + final CompositeByteBuf frames = this.frames; + this.frames = null; + frames.release(); + + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + currentSubscription.cancel(); + } + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = + ReassemblyUtils.addFollowingFrame(this.frames, followingFrame, this.maxInboundPayloadSize); + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + logger.debug("Reassembly has failed", t); + + S.lazySet(this, Operators.cancelledSubscription()); + this.done = true; + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + this.streamId, + new CanceledException("Failed to reassemble payload. Cause" + t.getMessage())); + this.sendProcessor.onNext(errorFrame); + return; + } + + Flux source = this.handler.requestStream(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java new file mode 100644 index 000000000..1f7b09af8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java @@ -0,0 +1,43 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import java.util.concurrent.CancellationException; +import reactor.util.annotation.Nullable; + +interface RequesterFrameHandler extends FrameHandler { + + void handlePayload(Payload payload); + + @Override + default void handleCancel() { + handleError( + new CancellationException( + "Cancellation was received but should not be possible for current request type")); + } + + @Override + default void handleRequestN(long n) { + // no ops + } + + @Nullable + CompositeByteBuf getFrames(); + + void setFrames(@Nullable CompositeByteBuf reassembledFrames); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java new file mode 100644 index 000000000..f5ddb199c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -0,0 +1,128 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.UnboundedProcessor; +import reactor.util.annotation.Nullable; + +class RequesterResponderSupport { + + private final int mtu; + private final int maxFrameLength; + private final int maxInboundPayloadSize; + private final PayloadDecoder payloadDecoder; + private final ByteBufAllocator allocator; + + @Nullable final StreamIdSupplier streamIdSupplier; + final IntObjectMap activeStreams; + + private final UnboundedProcessor sendProcessor; + + public RequesterResponderSupport( + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + PayloadDecoder payloadDecoder, + ByteBufAllocator allocator, + @Nullable StreamIdSupplier streamIdSupplier) { + + this.activeStreams = new IntObjectHashMap<>(); + this.mtu = mtu; + this.maxFrameLength = maxFrameLength; + this.maxInboundPayloadSize = maxInboundPayloadSize; + this.payloadDecoder = payloadDecoder; + this.allocator = allocator; + this.streamIdSupplier = streamIdSupplier; + this.sendProcessor = new UnboundedProcessor<>(); + } + + public int getMtu() { + return mtu; + } + + public int getMaxFrameLength() { + return maxFrameLength; + } + + public int getMaxInboundPayloadSize() { + return maxInboundPayloadSize; + } + + public PayloadDecoder getPayloadDecoder() { + return payloadDecoder; + } + + public ByteBufAllocator getAllocator() { + return allocator; + } + + public UnboundedProcessor getSendProcessor() { + return sendProcessor; + } + + /** + * Issues next {@code streamId} + * + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int getNextStreamId() { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + synchronized (this) { + return streamIdSupplier.nextStreamId(this.activeStreams); + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + /** + * Adds frameHandler and returns issued {@code streamId} back + * + * @param frameHandler to store + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int addAndGetNextStreamId(FrameHandler frameHandler) { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + final IntObjectMap activeStreams = this.activeStreams; + synchronized (this) { + final int streamId = streamIdSupplier.nextStreamId(activeStreams); + + activeStreams.put(streamId, frameHandler); + + return streamId; + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + /** + * Resolves {@link FrameHandler} by {@code streamId} + * + * @param streamId used to resolve {@link FrameHandler} + * @return {@link FrameHandler} or {@code null} + */ + @Nullable + public synchronized FrameHandler get(int streamId) { + return this.activeStreams.get(streamId); + } + + /** + * Removes {@link FrameHandler} if it is present and equals to the given one + * + * @param streamId to lookup for {@link FrameHandler} + * @param frameHandler instance to check with the found one + * @return {@code true} if there is {@link FrameHandler} for the given {@code streamId} and the + * instance equals to the passed one + */ + public synchronized boolean remove(int streamId, FrameHandler frameHandler) { + return this.activeStreams.remove(streamId, frameHandler); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java new file mode 100644 index 000000000..27cc8db9a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +interface ResponderFrameHandler extends FrameHandler { + + Logger logger = LoggerFactory.getLogger(ResponderFrameHandler.class); + + @Override + default void handleComplete() {} + + @Override + default void handleError(Throwable t) { + logger.debug("Dropped error", t); + handleCancel(); + } + + @Override + default void handleRequestN(long n) { + // no ops + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java new file mode 100644 index 000000000..8f86d76e1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java @@ -0,0 +1,326 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.isFragmentable; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.util.function.Consumer; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +final class SendUtils { + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + data -> { + try { + ReferenceCounted referenceCounted = (ReferenceCounted) data; + referenceCounted.release(); + } catch (Throwable e) { + // ignored + } + }; + + static final Context DISCARD_CONTEXT = Operators.enableOnDiscard(null, DROPPED_ELEMENTS_CONSUMER); + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + int mtu, + Payload payload, + UnboundedProcessor sendProcessor, + ByteBufAllocator allocator, + boolean requester) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, false); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, mtu, frameType, streamId, hasMetadata, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + throw e; + } + + sendProcessor.onNext(first); + + boolean complete = frameType == FrameType.NEXT_COMPLETE; + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, true, e); + throw e; + } + sendProcessor.onNext(following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_FNF: + requestFrame = + RequestFireAndForgetFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case REQUEST_RESPONSE: + requestFrame = + RequestResponseFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + requestFrame = + PayloadFrameCodec.encode( + allocator, + streamId, + false, + frameType == FrameType.NEXT_COMPLETE, + frameType != FrameType.PAYLOAD, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + sendProcessor.onNext(requestFrame); + } + } + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + long initialRequestN, + int mtu, + Payload payload, + UnboundedProcessor sendProcessor, + ByteBufAllocator allocator, + boolean complete) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, true); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, mtu, initialRequestN, frameType, streamId, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + throw e; + } + + sendProcessor.onNext(first); + + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, true, e); + throw e; + } + sendProcessor.onNext(following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_STREAM: + requestFrame = + RequestStreamFrameCodec.encode( + allocator, + streamId, + false, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + case REQUEST_CHANNEL: + requestFrame = + RequestChannelFrameCodec.encode( + allocator, + streamId, + false, + complete, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + sendProcessor.onNext(requestFrame); + } + } + + static void sendTerminalFrame( + int streamId, + FrameType frameType, + UnboundedProcessor sendProcessor, + ByteBufAllocator allocator, + boolean requester, + boolean onFollowingFrame, + Throwable t) { + + if (onFollowingFrame) { + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + sendProcessor.onNext(cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode fragmented " + + frameType + + " frame. Cause: " + + t.getMessage())); + sendProcessor.onNext(errorFrame); + } + } else { + switch (frameType) { + case NEXT_COMPLETE: + case NEXT: + case PAYLOAD: + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + sendProcessor.onNext(cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode " + frameType + " frame. Cause: " + t.getMessage())); + sendProcessor.onNext(errorFrame); + } + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java new file mode 100644 index 000000000..b3857bc12 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java @@ -0,0 +1,385 @@ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +final class StateUtils { + + /** Volatile Long Field bit mask that allows extract flags stored in the field */ + static final long FLAGS_MASK = + 0b111111111111111111111111111111111_0000000000000000000000000000000L; + /** Volatile Long Field bit mask that allows extract int RequestN stored in the field */ + static final long REQUEST_MASK = + 0b000000000000000000000000000000000_1111111111111111111111111111111L; + /** Bit Flag that indicates Requester Producer has been subscribed once */ + static final long SUBSCRIBED_FLAG = + 0b000000000000000000000000000000001_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that sent first initial frame was sent (in case of requester) or + * consumed (if responder) + */ + static final long FIRST_FRAME_SENT_FLAG = + 0b000000000000000000000000000000010_0000000000000000000000000000000L; + /** Bit Flag that indicates that there is a frame being reassembled */ + static final long REASSEMBLING_FLAG = + 0b000000000000000000000000000000100_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the inbound is terminated + */ + static final long INBOUND_TERMINATED_FLAG = + 0b000000000000000000000000000001000_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the outbound is terminated + */ + static final long OUTBOUND_TERMINATED_FLAG = + 0b000000000000000000000000000010000_0000000000000000000000000000000L; + /** Initial state for any request operator */ + static final long UNSUBSCRIBED_STATE = + 0b000000000000000000000000000000000_0000000000000000000000000000000L; + /** State that indicates request operator was terminated */ + static final long TERMINATED_STATE = + 0b100000000000000000000000000000000_0000000000000000000000000000000L; + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | SUBSCRIBED_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the given stream has already been subscribed once + * + * @param state to check whether stream is subscribed + * @return true if the {@link #SUBSCRIBED_FLAG} flag is set + */ + static boolean isSubscribed(long state) { + return (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_FRAME_SENT_FLAG} flag which indicates + * that the first frame has already set and logical stream has already been established. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been established once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstFrameSent(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_FRAME_SENT_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the first frame which established logical stream has already been sent + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_FRAME_SENT_FLAG} flag is set + */ + static boolean isFirstFrameSent(long state) { + return (state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #REASSEMBLING_FLAG} flag which indicates that + * there is a payload reassembling in progress. + * + *

Note, the flag will not be added if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembling(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state | REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Removes (if possible) from the given state the {@link #REASSEMBLING_FLAG} flag which indicates + * that a payload reassembly process is completed. + * + *

Note, the flag will not be removed if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembled(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state & ~REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Indicates that a payload reassembly process is completed. + * + * @param state to check whether there is reassembly in progress + * @return true if the {@link #REASSEMBLING_FLAG} flag is set + */ + static boolean isReassembling(long state) { + return (state & REASSEMBLING_FLAG) == REASSEMBLING_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #INBOUND_TERMINATED_FLAG} flag which indicates + * that an inbound channel of a bidirectional stream is terminated. + * + *

Note, this action will have no effect if the stream has already been terminated or if + * the {@link #INBOUND_TERMINATED_FLAG} flag has already been set.
+ * Note, if the outbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markInboundTerminated(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + return state; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | INBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the inbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #INBOUND_TERMINATED_FLAG} set + * @return true if the {@link #INBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isInboundTerminated(long state) { + return (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #OUTBOUND_TERMINATED_FLAG} flag which + * indicates that an outbound channel of a bidirectional stream is terminated. + * + *

Note, this action will have no effect if the stream has already been terminated or if + * the {@link #OUTBOUND_TERMINATED_FLAG} flag has already been set.
+ * Note, if the {@code checkEstablishment} parameter is {@code true} and the logical stream + * is not established, then the result state will be {@link #TERMINATED_STATE}
+ * Note, if the inbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param checkEstablishment indicates whether {@link #FIRST_FRAME_SENT_FLAG} should be checked to + * make final decision + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markOutboundTerminated( + AtomicLongFieldUpdater updater, T instance, boolean checkEstablishment) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + return state; + } + + if ((checkEstablishment && !isFirstFrameSent(state)) + || (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | OUTBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isOutboundTerminated(long state) { + return (state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG; + } + + /** + * Makes current state a {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markTerminated(AtomicLongFieldUpdater updater, T instance) { + return updater.getAndSet(instance, TERMINATED_STATE); + } + + /** + * Makes current state a {@link #TERMINATED_STATE} using {@link + * AtomicLongFieldUpdater#lazySet(Object, long)} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + */ + static void lazyTerminate(AtomicLongFieldUpdater updater, T instance) { + updater.lazySet(instance, TERMINATED_STATE); + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isTerminated(long state) { + return state == TERMINATED_STATE; + } + + /** + * Shortcut for {@link #isSubscribed} {@code ||} {@link #isTerminated} methods + * + * @param state to check flags on + * @return true if state is terminated or has flag subscribed + */ + static boolean isSubscribedOrTerminated(long state) { + return state == TERMINATED_STATE || (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + /** + * @param updater + * @param instance + * @param toAdd + * @param + * @return + */ + static long addRequestN(AtomicLongFieldUpdater updater, T instance, long toAdd) { + long currentState, flags, requestN, nextRequestN; + for (; ; ) { + currentState = updater.get(instance); + + if (currentState == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + requestN = currentState & REQUEST_MASK; + if (requestN == REQUEST_MASK) { + return currentState; + } + + flags = currentState & FLAGS_MASK; + nextRequestN = addRequestN(requestN, toAdd); + + if (updater.compareAndSet(instance, currentState, nextRequestN | flags)) { + return currentState; + } + } + } + + static long addRequestN(long a, long b) { + long res = a + b; + if (res < 0 || res > REQUEST_MASK) { + return REQUEST_MASK; + } + return res; + } + + static boolean hasRequested(long state) { + return (state & REQUEST_MASK) > 0; + } + + static long extractRequestN(long state) { + long requestN = state & REQUEST_MASK; + + if (requestN == REQUEST_MASK) { + return REQUEST_MASK; + } + + return requestN; + } + + static boolean isMaxAllowedRequestN(long n) { + return n >= REQUEST_MASK; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java deleted file mode 100644 index 6eebd676c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import static io.rsocket.fragmentation.FrameFragmenter.fragmentFrame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameType; -import java.util.Objects; -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * A {@link DuplexConnection} implementation that fragments and reassembles {@link ByteBuf}s. - * - * @see Fragmentation - * and Reassembly - */ -public final class FragmentationDuplexConnection extends ReassemblyDuplexConnection - implements DuplexConnection { - - public static final int MIN_MTU_SIZE = 64; - - private static final Logger logger = LoggerFactory.getLogger(FragmentationDuplexConnection.class); - - final DuplexConnection delegate; - final int mtu; - final String type; - - /** - * Class constructor. - * - * @param delegate the underlying connection - * @param mtu the fragment size, greater than {@link #MIN_MTU_SIZE} - * @param maxInboundPayloadSize the maximum payload size, which can be reassembled from multiple - * fragments - * @param type a label to use for logging purposes - */ - public FragmentationDuplexConnection( - DuplexConnection delegate, int mtu, int maxInboundPayloadSize, String type) { - super(delegate, maxInboundPayloadSize); - - Objects.requireNonNull(delegate, "delegate must not be null"); - this.delegate = delegate; - this.mtu = assertMtu(mtu); - this.type = type; - } - - private boolean shouldFragment(FrameType frameType, int readableBytes) { - return frameType.isFragmentable() && readableBytes > mtu; - } - - public static int assertMtu(int mtu) { - if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { - String msg = - String.format( - "The smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); - throw new IllegalArgumentException(msg); - } else { - return mtu; - } - } - - @Override - public Mono send(Publisher frames) { - return Flux.from(frames).concatMap(this::sendOne).then(); - } - - @Override - public Mono sendOne(ByteBuf frame) { - FrameType frameType = FrameHeaderCodec.frameType(frame); - int readableBytes = frame.readableBytes(); - if (!shouldFragment(frameType, readableBytes)) { - return delegate.sendOne(frame); - } - Flux fragments = Flux.from(fragmentFrame(alloc(), mtu, frame, frameType)); - if (logger.isDebugEnabled()) { - fragments = - fragments.doOnNext( - byteBuf -> { - logger.debug( - "{} - stream id {} - frame type {} - \n {}", - type, - FrameHeaderCodec.streamId(byteBuf), - FrameHeaderCodec.frameType(byteBuf), - ByteBufUtil.prettyHexDump(byteBuf)); - }); - } - return delegate.send(fragments); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java deleted file mode 100644 index fcb6198a3..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameType; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.frame.RequestChannelFrameCodec; -import io.rsocket.frame.RequestFireAndForgetFrameCodec; -import io.rsocket.frame.RequestResponseFrameCodec; -import io.rsocket.frame.RequestStreamFrameCodec; -import java.util.function.Consumer; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.SynchronousSink; - -/** - * The implementation of the RSocket fragmentation behavior. - * - * @see Fragmentation - * and Reassembly - */ -final class FrameFragmenter { - static Publisher fragmentFrame( - ByteBufAllocator allocator, int mtu, final ByteBuf frame, FrameType frameType) { - ByteBuf metadata = getMetadata(frame, frameType); - ByteBuf data = getData(frame, frameType); - int streamId = FrameHeaderCodec.streamId(frame); - return Flux.generate( - new Consumer>() { - boolean first = true; - - @Override - public void accept(SynchronousSink sink) { - ByteBuf byteBuf; - if (first) { - first = false; - byteBuf = - encodeFirstFragment( - allocator, mtu, frame, frameType, streamId, metadata, data); - } else { - byteBuf = encodeFollowsFragment(allocator, mtu, streamId, metadata, data); - } - - sink.next(byteBuf); - if (!metadata.isReadable() && !data.isReadable()) { - sink.complete(); - } - } - }) - .doFinally(signalType -> ReferenceCountUtil.safeRelease(frame)); - } - - static ByteBuf encodeFirstFragment( - ByteBufAllocator allocator, - int mtu, - ByteBuf frame, - FrameType frameType, - int streamId, - ByteBuf metadata, - ByteBuf data) { - // subtract the header bytes - int remaining = mtu - FrameHeaderCodec.size(); - - // substract the initial request n - switch (frameType) { - case REQUEST_STREAM: - case REQUEST_CHANNEL: - remaining -= Integer.BYTES; - break; - default: - } - - ByteBuf metadataFragment = null; - if (metadata.isReadable()) { - // subtract the metadata frame length - remaining -= 3; - int r = Math.min(remaining, metadata.readableBytes()); - remaining -= r; - metadataFragment = metadata.readRetainedSlice(r); - } - - ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; - if (remaining > 0 && data.isReadable()) { - int r = Math.min(remaining, data.readableBytes()); - dataFragment = data.readRetainedSlice(r); - } - - switch (frameType) { - case REQUEST_FNF: - return RequestFireAndForgetFrameCodec.encode( - allocator, streamId, true, metadataFragment, dataFragment); - case REQUEST_STREAM: - return RequestStreamFrameCodec.encode( - allocator, - streamId, - true, - RequestStreamFrameCodec.initialRequestN(frame), - metadataFragment, - dataFragment); - case REQUEST_RESPONSE: - return RequestResponseFrameCodec.encode( - allocator, streamId, true, metadataFragment, dataFragment); - case REQUEST_CHANNEL: - return RequestChannelFrameCodec.encode( - allocator, - streamId, - true, - false, - RequestChannelFrameCodec.initialRequestN(frame), - metadataFragment, - dataFragment); - // Payload and synthetic types - case PAYLOAD: - return PayloadFrameCodec.encode( - allocator, streamId, true, false, false, metadataFragment, dataFragment); - case NEXT: - return PayloadFrameCodec.encode( - allocator, streamId, true, false, true, metadataFragment, dataFragment); - case NEXT_COMPLETE: - return PayloadFrameCodec.encode( - allocator, streamId, true, true, true, metadataFragment, dataFragment); - case COMPLETE: - return PayloadFrameCodec.encode( - allocator, streamId, true, true, false, metadataFragment, dataFragment); - default: - throw new IllegalStateException("unsupported fragment type: " + frameType); - } - } - - static ByteBuf encodeFollowsFragment( - ByteBufAllocator allocator, int mtu, int streamId, ByteBuf metadata, ByteBuf data) { - // subtract the header bytes - int remaining = mtu - FrameHeaderCodec.size(); - - ByteBuf metadataFragment = null; - if (metadata.isReadable()) { - // subtract the metadata frame length - remaining -= 3; - int r = Math.min(remaining, metadata.readableBytes()); - remaining -= r; - metadataFragment = metadata.readRetainedSlice(r); - } - - ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; - if (remaining > 0 && data.isReadable()) { - int r = Math.min(remaining, data.readableBytes()); - dataFragment = data.readRetainedSlice(r); - } - - boolean follows = data.isReadable() || metadata.isReadable(); - return PayloadFrameCodec.encode( - allocator, streamId, follows, false, true, metadataFragment, dataFragment); - } - - static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { - boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); - if (hasMetadata) { - ByteBuf metadata; - switch (frameType) { - case REQUEST_FNF: - metadata = RequestFireAndForgetFrameCodec.metadata(frame); - break; - case REQUEST_STREAM: - metadata = RequestStreamFrameCodec.metadata(frame); - break; - case REQUEST_RESPONSE: - metadata = RequestResponseFrameCodec.metadata(frame); - break; - case REQUEST_CHANNEL: - metadata = RequestChannelFrameCodec.metadata(frame); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - metadata = PayloadFrameCodec.metadata(frame); - break; - default: - throw new IllegalStateException("unsupported fragment type"); - } - return metadata; - } else { - return Unpooled.EMPTY_BUFFER; - } - } - - static ByteBuf getData(ByteBuf frame, FrameType frameType) { - ByteBuf data; - switch (frameType) { - case REQUEST_FNF: - data = RequestFireAndForgetFrameCodec.data(frame); - break; - case REQUEST_STREAM: - data = RequestStreamFrameCodec.data(frame); - break; - case REQUEST_RESPONSE: - data = RequestResponseFrameCodec.data(frame); - break; - case REQUEST_CHANNEL: - data = RequestChannelFrameCodec.data(frame); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - data = PayloadFrameCodec.data(frame); - break; - default: - throw new IllegalStateException("unsupported fragment type"); - } - return data; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java deleted file mode 100644 index d1adbfdf7..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.collection.IntObjectHashMap; -import io.netty.util.collection.IntObjectMap; -import io.rsocket.frame.*; -import java.util.concurrent.atomic.AtomicBoolean; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.SynchronousSink; -import reactor.util.annotation.Nullable; - -/** - * The implementation of the RSocket reassembly behavior. - * - * @see Fragmentation - * and Reassembly - */ -final class FrameReassembler extends AtomicBoolean implements Disposable { - - private static final long serialVersionUID = -4394598098863449055L; - - private static final Logger logger = LoggerFactory.getLogger(FrameReassembler.class); - - final IntObjectMap headers; - final IntObjectMap metadata; - final IntObjectMap data; - - final ByteBufAllocator allocator; - final int maxInboundPayloadSize; - - public FrameReassembler(ByteBufAllocator allocator, int maxInboundPayloadSize) { - this.allocator = allocator; - this.maxInboundPayloadSize = maxInboundPayloadSize; - this.headers = new IntObjectHashMap<>(); - this.metadata = new IntObjectHashMap<>(); - this.data = new IntObjectHashMap<>(); - } - - @Override - public void dispose() { - if (compareAndSet(false, true)) { - synchronized (FrameReassembler.this) { - for (ByteBuf byteBuf : headers.values()) { - ReferenceCountUtil.safeRelease(byteBuf); - } - headers.clear(); - - for (ByteBuf byteBuf : metadata.values()) { - ReferenceCountUtil.safeRelease(byteBuf); - } - metadata.clear(); - - for (ByteBuf byteBuf : data.values()) { - ReferenceCountUtil.safeRelease(byteBuf); - } - data.clear(); - } - } - } - - @Override - public boolean isDisposed() { - return get(); - } - - @Nullable - synchronized ByteBuf getHeader(int streamId) { - return headers.get(streamId); - } - - synchronized CompositeByteBuf getMetadata(int streamId) { - CompositeByteBuf byteBuf = metadata.get(streamId); - - if (byteBuf == null) { - byteBuf = allocator.compositeBuffer(); - metadata.put(streamId, byteBuf); - } - - return byteBuf; - } - - synchronized int getMetadataSize(int streamId) { - CompositeByteBuf byteBuf = metadata.get(streamId); - - if (byteBuf == null) { - return 0; - } - - return byteBuf.readableBytes(); - } - - synchronized CompositeByteBuf getData(int streamId) { - CompositeByteBuf byteBuf = data.get(streamId); - - if (byteBuf == null) { - byteBuf = allocator.compositeBuffer(); - data.put(streamId, byteBuf); - } - - return byteBuf; - } - - synchronized int getDataSize(int streamId) { - CompositeByteBuf byteBuf = data.get(streamId); - - if (byteBuf == null) { - return 0; - } - - return byteBuf.readableBytes(); - } - - @Nullable - synchronized ByteBuf removeHeader(int streamId) { - return headers.remove(streamId); - } - - @Nullable - synchronized CompositeByteBuf removeMetadata(int streamId) { - return metadata.remove(streamId); - } - - @Nullable - synchronized CompositeByteBuf removeData(int streamId) { - return data.remove(streamId); - } - - synchronized void putHeader(int streamId, ByteBuf header) { - headers.put(streamId, header); - } - - void cancelAssemble(int streamId) { - ByteBuf header = removeHeader(streamId); - CompositeByteBuf metadata = removeMetadata(streamId); - CompositeByteBuf data = removeData(streamId); - - if (header != null) { - ReferenceCountUtil.safeRelease(header); - } - - if (metadata != null) { - ReferenceCountUtil.safeRelease(metadata); - } - - if (data != null) { - ReferenceCountUtil.safeRelease(data); - } - } - - void handleNoFollowsFlag(ByteBuf frame, SynchronousSink sink, int streamId) { - ByteBuf header = removeHeader(streamId); - if (header != null) { - - int maxReassemblySize = this.maxInboundPayloadSize; - if (maxReassemblySize != Integer.MAX_VALUE) { - int currentPayloadSize = getMetadataSize(streamId) + getDataSize(streamId); - if (currentPayloadSize + frame.readableBytes() - FrameHeaderCodec.size() - > maxReassemblySize) { - frame.release(); - throw new IllegalStateException("Reassembled payload went out of allowed size"); - } - } - - if (FrameHeaderCodec.hasMetadata(header)) { - ByteBuf assembledFrame = assembleFrameWithMetadata(frame, streamId, header); - sink.next(assembledFrame); - } else { - ByteBuf data = assembleData(frame, streamId); - ByteBuf assembledFrame = FragmentationCodec.encode(allocator, header, data); - sink.next(assembledFrame); - } - frame.release(); - } else { - sink.next(frame); - } - } - - void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { - - int maxReassemblySize = this.maxInboundPayloadSize; - if (maxReassemblySize != Integer.MAX_VALUE) { - int currentPayloadSize = getMetadataSize(streamId) + getDataSize(streamId); - if (currentPayloadSize + frame.readableBytes() - FrameHeaderCodec.size() - > maxReassemblySize) { - frame.release(); - throw new IllegalStateException("Reassembled payload went out of allowed size"); - } - } - - ByteBuf header = getHeader(streamId); - if (header == null) { - header = frame.copy(frame.readerIndex(), FrameHeaderCodec.size()); - - if (frameType == FrameType.REQUEST_CHANNEL || frameType == FrameType.REQUEST_STREAM) { - long i = RequestChannelFrameCodec.initialRequestN(frame); - header.writeInt(i > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) i); - } - putHeader(streamId, header); - } - - if (FrameHeaderCodec.hasMetadata(frame)) { - CompositeByteBuf metadata = getMetadata(streamId); - switch (frameType) { - case REQUEST_FNF: - metadata.addComponents(true, RequestFireAndForgetFrameCodec.metadata(frame).retain()); - break; - case REQUEST_STREAM: - metadata.addComponents(true, RequestStreamFrameCodec.metadata(frame).retain()); - break; - case REQUEST_RESPONSE: - metadata.addComponents(true, RequestResponseFrameCodec.metadata(frame).retain()); - break; - case REQUEST_CHANNEL: - metadata.addComponents(true, RequestChannelFrameCodec.metadata(frame).retain()); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - metadata.addComponents(true, PayloadFrameCodec.metadata(frame).retain()); - break; - default: - throw new IllegalStateException("unsupported fragment type"); - } - } - - ByteBuf data; - switch (frameType) { - case REQUEST_FNF: - data = RequestFireAndForgetFrameCodec.data(frame).retain(); - break; - case REQUEST_STREAM: - data = RequestStreamFrameCodec.data(frame).retain(); - break; - case REQUEST_RESPONSE: - data = RequestResponseFrameCodec.data(frame).retain(); - break; - case REQUEST_CHANNEL: - data = RequestChannelFrameCodec.data(frame).retain(); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - data = PayloadFrameCodec.data(frame).retain(); - break; - default: - frame.release(); - throw new IllegalStateException("unsupported fragment type"); - } - - getData(streamId).addComponents(true, data); - frame.release(); - } - - void reassembleFrame(ByteBuf frame, SynchronousSink sink) { - try { - FrameType frameType = FrameHeaderCodec.frameType(frame); - int streamId = FrameHeaderCodec.streamId(frame); - switch (frameType) { - case CANCEL: - case ERROR: - cancelAssemble(streamId); - } - - if (!frameType.isFragmentable()) { - sink.next(frame); - return; - } - - boolean hasFollows = FrameHeaderCodec.hasFollows(frame); - - if (hasFollows) { - handleFollowsFlag(frame, streamId, frameType); - } else { - handleNoFollowsFlag(frame, sink, streamId); - } - - } catch (Throwable t) { - logger.error("error reassemble frame", t); - sink.error(t); - } - } - - private ByteBuf assembleFrameWithMetadata(ByteBuf frame, int streamId, ByteBuf header) { - ByteBuf metadata; - CompositeByteBuf cm = removeMetadata(streamId); - - ByteBuf decodedMetadata = PayloadFrameCodec.metadata(frame); - if (decodedMetadata != null) { - if (cm != null) { - metadata = cm.addComponents(true, decodedMetadata.retain()); - } else { - metadata = PayloadFrameCodec.metadata(frame).retain(); - } - } else { - metadata = cm; - } - - ByteBuf data = assembleData(frame, streamId); - - return FragmentationCodec.encode(allocator, header, metadata, data); - } - - private ByteBuf assembleData(ByteBuf frame, int streamId) { - ByteBuf data; - CompositeByteBuf cd = removeData(streamId); - if (cd != null) { - cd.addComponents(true, PayloadFrameCodec.data(frame).retain()); - data = cd; - } else { - data = Unpooled.EMPTY_BUFFER; - } - - return data; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java deleted file mode 100644 index 03f97c75d..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameLengthCodec; -import java.util.Objects; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * A {@link DuplexConnection} implementation that reassembles {@link ByteBuf}s. - * - * @see Fragmentation - * and Reassembly - */ -public class ReassemblyDuplexConnection implements DuplexConnection { - private final DuplexConnection delegate; - private final FrameReassembler frameReassembler; - - /** Constructor with the underlying delegate to receive frames from. */ - public ReassemblyDuplexConnection(DuplexConnection delegate, int maxInboundPayloadSize) { - Objects.requireNonNull(delegate, "delegate must not be null"); - this.delegate = delegate; - this.frameReassembler = new FrameReassembler(delegate.alloc(), maxInboundPayloadSize); - - delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); - } - - public static int assertInboundPayloadSize(int inboundPayloadSize) { - if (inboundPayloadSize < FragmentationDuplexConnection.MIN_MTU_SIZE) { - String msg = - String.format( - "The min allowed inboundPayloadSize size is %d bytes, provided: %d", - FrameLengthCodec.FRAME_LENGTH_MASK, inboundPayloadSize); - throw new IllegalArgumentException(msg); - } else { - return inboundPayloadSize; - } - } - - @Override - public Mono send(Publisher frames) { - return delegate.send(frames); - } - - @Override - public Mono sendOne(ByteBuf frame) { - return delegate.sendOne(frame); - } - - @Override - public Flux receive() { - return delegate.receive().handle(frameReassembler::reassembleFrame); - } - - @Override - public ByteBufAllocator alloc() { - return delegate.alloc(); - } - - @Override - public Mono onClose() { - return delegate.onClose(); - } - - @Override - public void dispose() { - delegate.dispose(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java b/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java deleted file mode 100644 index fd6bf0aed..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java +++ /dev/null @@ -1,748 +0,0 @@ -/* - * Copyright 2014 The Netty Project - * - * The Netty Project licenses this file to you under the Apache License, version 2.0 (the - * "License"); you may not use this file except in compliance with the License. You may obtain a - * copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. - */ - -package io.rsocket.internal; - -import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo; - -import io.netty.util.collection.IntObjectMap; -import java.util.AbstractCollection; -import java.util.AbstractSet; -import java.util.Arrays; -import java.util.Collection; -import java.util.Iterator; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Set; - -/** - * A hash map implementation of {@link IntObjectMap} that uses open addressing for keys. To minimize - * the memory footprint, this class uses open addressing rather than chaining. Collisions are - * resolved using linear probing. Deletions implement compaction, so cost of remove can approach - * O(N) for full maps, which makes a small loadFactor recommended. - * - * @param The value type stored in the map. - */ -public class SynchronizedIntObjectHashMap implements IntObjectMap { - - /** Default initial capacity. Used if not specified in the constructor */ - public static final int DEFAULT_CAPACITY = 8; - - /** Default load factor. Used if not specified in the constructor */ - public static final float DEFAULT_LOAD_FACTOR = 0.5f; - - /** - * Placeholder for null values, so we can use the actual null to mean available. (Better than - * using a placeholder for available: less references for GC processing.) - */ - private static final Object NULL_VALUE = new Object(); - - /** The maximum number of elements allowed without allocating more space. */ - private int maxSize; - - /** The load factor for the map. Used to calculate {@link #maxSize}. */ - private final float loadFactor; - - private int[] keys; - private V[] values; - private int size; - private int mask; - - private final Set keySet = new KeySet(); - private final Set> entrySet = new EntrySet(); - private final Iterable> entries = PrimitiveIterator::new; - - public SynchronizedIntObjectHashMap() { - this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR); - } - - public SynchronizedIntObjectHashMap(int initialCapacity) { - this(initialCapacity, DEFAULT_LOAD_FACTOR); - } - - public SynchronizedIntObjectHashMap(int initialCapacity, float loadFactor) { - if (loadFactor <= 0.0f || loadFactor > 1.0f) { - // Cannot exceed 1 because we can never store more than capacity elements; - // using a bigger loadFactor would trigger rehashing before the desired load is reached. - throw new IllegalArgumentException("loadFactor must be > 0 and <= 1"); - } - - this.loadFactor = loadFactor; - - // Adjust the initial capacity if necessary. - int capacity = safeFindNextPositivePowerOfTwo(initialCapacity); - mask = capacity - 1; - - // Allocate the arrays. - keys = new int[capacity]; - @SuppressWarnings({"unchecked", "SuspiciousArrayCast"}) - V[] temp = (V[]) new Object[capacity]; - values = temp; - - // Initialize the maximum size value. - maxSize = calcMaxSize(capacity); - } - - private static T toExternal(T value) { - assert value != null : "null is not a legitimate internal value. Concurrent Modification?"; - return value == NULL_VALUE ? null : value; - } - - @SuppressWarnings("unchecked") - private static T toInternal(T value) { - return value == null ? (T) NULL_VALUE : value; - } - - public synchronized V[] getValuesCopy() { - V[] values = this.values; - return Arrays.copyOf(values, values.length); - } - - @Override - public synchronized V get(int key) { - int index = indexOf(key); - return index == -1 ? null : toExternal(values[index]); - } - - @Override - public synchronized V put(int key, V value) { - int startIndex = hashIndex(key); - int index = startIndex; - - for (; ; ) { - if (values[index] == null) { - // Found empty slot, use it. - keys[index] = key; - values[index] = toInternal(value); - growSize(); - return null; - } - if (keys[index] == key) { - // Found existing entry with this key, just replace the value. - V previousValue = values[index]; - values[index] = toInternal(value); - return toExternal(previousValue); - } - - // Conflict, keep probing ... - if ((index = probeNext(index)) == startIndex) { - // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow. - throw new IllegalStateException("Unable to insert"); - } - } - } - - @Override - public synchronized void putAll(Map sourceMap) { - if (sourceMap instanceof SynchronizedIntObjectHashMap) { - // Optimization - iterate through the arrays. - @SuppressWarnings("unchecked") - SynchronizedIntObjectHashMap source = (SynchronizedIntObjectHashMap) sourceMap; - for (int i = 0; i < source.values.length; ++i) { - V sourceValue = source.values[i]; - if (sourceValue != null) { - put(source.keys[i], sourceValue); - } - } - return; - } - - // Otherwise, just add each entry. - for (Entry entry : sourceMap.entrySet()) { - put(entry.getKey(), entry.getValue()); - } - } - - @Override - public synchronized V remove(int key) { - int index = indexOf(key); - if (index == -1) { - return null; - } - - V prev = values[index]; - removeAt(index); - return toExternal(prev); - } - - @Override - public synchronized int size() { - return size; - } - - @Override - public synchronized boolean isEmpty() { - return size == 0; - } - - @Override - public synchronized void clear() { - Arrays.fill(keys, 0); - Arrays.fill(values, null); - size = 0; - } - - @Override - public synchronized boolean containsKey(int key) { - return indexOf(key) >= 0; - } - - @Override - public synchronized boolean containsValue(Object value) { - @SuppressWarnings("unchecked") - V v1 = toInternal((V) value); - for (V v2 : values) { - // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE). - if (v2 != null && v2.equals(v1)) { - return true; - } - } - return false; - } - - @Override - public synchronized Iterable> entries() { - return entries; - } - - @Override - public synchronized Collection values() { - return new AbstractCollection() { - @Override - public Iterator iterator() { - return new Iterator() { - final PrimitiveIterator iter = new PrimitiveIterator(); - - @Override - public boolean hasNext() { - return iter.hasNext(); - } - - @Override - public V next() { - return iter.next().value(); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - - @Override - public int size() { - return size; - } - }; - } - - @Override - public synchronized int hashCode() { - // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys - // array, which may have different lengths for two maps of same size(), so the - // capacity cannot be used as input for hashing but the size can. - int hash = size; - for (int key : keys) { - // 0 can be a valid key or unused slot, but won't impact the hashcode in either case. - // This way we can use a cheap loop without conditionals, or hard-to-unroll operations, - // or the devastatingly bad memory locality of visiting value objects. - // Also, it's important to use a hash function that does not depend on the ordering - // of terms, only their values; since the map is an unordered collection and - // entries can end up in different positions in different maps that have the same - // elements, but with different history of puts/removes, due to conflicts. - hash ^= hashCode(key); - } - return hash; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (!(obj instanceof IntObjectMap)) { - return false; - } - @SuppressWarnings("rawtypes") - IntObjectMap other = (IntObjectMap) obj; - synchronized (this) { - if (size != other.size()) { - return false; - } - for (int i = 0; i < values.length; ++i) { - V value = values[i]; - if (value != null) { - int key = keys[i]; - Object otherValue = other.get(key); - if (value == NULL_VALUE) { - if (otherValue != null) { - return false; - } - } else if (!value.equals(otherValue)) { - return false; - } - } - } - } - return true; - } - - @Override - public synchronized boolean containsKey(Object key) { - return containsKey(objectToKey(key)); - } - - @Override - public synchronized V get(Object key) { - return get(objectToKey(key)); - } - - @Override - public synchronized V put(Integer key, V value) { - return put(objectToKey(key), value); - } - - @Override - public synchronized V remove(Object key) { - return remove(objectToKey(key)); - } - - @Override - public synchronized Set keySet() { - return keySet; - } - - @Override - public synchronized Set> entrySet() { - return entrySet; - } - - private int objectToKey(Object key) { - return ((Integer) key).intValue(); - } - - /** - * Locates the index for the given key. This method probes using double hashing. - * - * @param key the key for an entry in the map. - * @return the index where the key was found, or {@code -1} if no entry is found for that key. - */ - private int indexOf(int key) { - int startIndex = hashIndex(key); - int index = startIndex; - - for (; ; ) { - if (values[index] == null) { - // It's available, so no chance that this value exists anywhere in the map. - return -1; - } - if (key == keys[index]) { - return index; - } - - // Conflict, keep probing ... - if ((index = probeNext(index)) == startIndex) { - return -1; - } - } - } - - /** Returns the hashed index for the given key. */ - private int hashIndex(int key) { - // The array lengths are always a power of two, so we can use a bitmask to stay inside the array - // bounds. - return hashCode(key) & mask; - } - - /** Returns the hash code for the key. */ - private static int hashCode(int key) { - return key; - } - - /** Get the next sequential index after {@code index} and wraps if necessary. */ - private int probeNext(int index) { - // The array lengths are always a power of two, so we can use a bitmask to stay inside the array - // bounds. - return (index + 1) & mask; - } - - /** Grows the map size after an insertion. If necessary, performs a rehash of the map. */ - private void growSize() { - size++; - - if (size > maxSize) { - if (keys.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Max capacity reached at size=" + size); - } - - // Double the capacity. - rehash(keys.length << 1); - } - } - - /** - * Removes entry at the given index position. Also performs opportunistic, incremental rehashing - * if necessary to not break conflict chains. - * - * @param index the index position of the element to remove. - * @return {@code true} if the next item was moved back. {@code false} otherwise. - */ - private boolean removeAt(final int index) { - --size; - // Clearing the key is not strictly necessary (for GC like in a regular collection), - // but recommended for security. The memory location is still fresh in the cache anyway. - keys[index] = 0; - values[index] = null; - - // In the interval from index to the next available entry, the arrays may have entries - // that are displaced from their base position due to prior conflicts. Iterate these - // entries and move them back if possible, optimizing future lookups. - // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap. - - int nextFree = index; - int i = probeNext(index); - for (V value = values[i]; value != null; value = values[i = probeNext(i)]) { - int key = keys[i]; - int bucket = hashIndex(key); - if (i < bucket && (bucket <= nextFree || nextFree <= i) - || bucket <= nextFree && nextFree <= i) { - // Move the displaced entry "back" to the first available position. - keys[nextFree] = key; - values[nextFree] = value; - // Put the first entry after the displaced entry - keys[i] = 0; - values[i] = null; - nextFree = i; - } - } - return nextFree != index; - } - - /** Calculates the maximum size allowed before rehashing. */ - private int calcMaxSize(int capacity) { - // Clip the upper bound so that there will always be at least one available slot. - int upperBound = capacity - 1; - return Math.min(upperBound, (int) (capacity * loadFactor)); - } - - /** - * Rehashes the map for the given capacity. - * - * @param newCapacity the new capacity for the map. - */ - private void rehash(int newCapacity) { - int[] oldKeys = keys; - V[] oldVals = values; - - keys = new int[newCapacity]; - @SuppressWarnings({"unchecked", "SuspiciousArrayCast"}) - V[] temp = (V[]) new Object[newCapacity]; - values = temp; - - maxSize = calcMaxSize(newCapacity); - mask = newCapacity - 1; - - // Insert to the new arrays. - for (int i = 0; i < oldVals.length; ++i) { - V oldVal = oldVals[i]; - if (oldVal != null) { - // Inlined put(), but much simpler: we don't need to worry about - // duplicated keys, growing/rehashing, or failing to insert. - int oldKey = oldKeys[i]; - int index = hashIndex(oldKey); - - for (; ; ) { - if (values[index] == null) { - keys[index] = oldKey; - values[index] = oldVal; - break; - } - - // Conflict, keep probing. Can wrap around, but never reaches startIndex again. - index = probeNext(index); - } - } - } - } - - @Override - public synchronized String toString() { - if (isEmpty()) { - return "{}"; - } - StringBuilder sb = new StringBuilder(4 * size); - sb.append('{'); - boolean first = true; - for (int i = 0; i < values.length; ++i) { - V value = values[i]; - if (value != null) { - if (!first) { - sb.append(", "); - } - sb.append(keyToString(keys[i])) - .append('=') - .append(value == this ? "(this Map)" : toExternal(value)); - first = false; - } - } - return sb.append('}').toString(); - } - - /** - * Helper method called by {@link #toString()} in order to convert a single map key into a string. - * This is protected to allow subclasses to override the appearance of a given key. - */ - protected String keyToString(int key) { - return Integer.toString(key); - } - - /** Set implementation for iterating over the entries of the map. */ - private final class EntrySet extends AbstractSet> { - @Override - public Iterator> iterator() { - return new MapIterator(); - } - - @Override - public int size() { - return SynchronizedIntObjectHashMap.this.size(); - } - } - - /** Set implementation for iterating over the keys. */ - private final class KeySet extends AbstractSet { - @Override - public int size() { - return SynchronizedIntObjectHashMap.this.size(); - } - - @Override - public boolean contains(Object o) { - return SynchronizedIntObjectHashMap.this.containsKey(o); - } - - @Override - public boolean remove(Object o) { - return SynchronizedIntObjectHashMap.this.remove(o) != null; - } - - @Override - public boolean retainAll(Collection retainedKeys) { - synchronized (SynchronizedIntObjectHashMap.this) { - boolean changed = false; - for (Iterator> iter = entries().iterator(); iter.hasNext(); ) { - PrimitiveEntry entry = iter.next(); - if (!retainedKeys.contains(entry.key())) { - changed = true; - iter.remove(); - } - } - return changed; - } - } - - @Override - public void clear() { - SynchronizedIntObjectHashMap.this.clear(); - } - - @Override - public Iterator iterator() { - synchronized (SynchronizedIntObjectHashMap.this) { - final Iterator> iter = entrySet.iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - synchronized (SynchronizedIntObjectHashMap.this) { - return iter.hasNext(); - } - } - - @Override - public Integer next() { - synchronized (SynchronizedIntObjectHashMap.this) { - return iter.next().getKey(); - } - } - - @Override - public void remove() { - synchronized (SynchronizedIntObjectHashMap.this) { - iter.remove(); - } - } - }; - } - } - } - - /** - * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link - * #next()}. - */ - private final class PrimitiveIterator implements Iterator>, PrimitiveEntry { - private int prevIndex = -1; - private int nextIndex = -1; - private int entryIndex = -1; - - private void scanNext() { - while (++nextIndex != values.length && values[nextIndex] == null) {} - } - - @Override - public boolean hasNext() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (nextIndex == -1) { - scanNext(); - } - return nextIndex != values.length; - } - } - - @Override - public PrimitiveEntry next() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (!hasNext()) { - throw new NoSuchElementException(); - } - - prevIndex = nextIndex; - scanNext(); - - // Always return the same Entry object, just change its index each time. - entryIndex = prevIndex; - return this; - } - } - - @Override - public void remove() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (prevIndex == -1) { - throw new IllegalStateException("next must be called before each remove."); - } - if (removeAt(prevIndex)) { - // removeAt may move elements "back" in the array if they have been displaced because - // their - // spot in the - // array was occupied when they were inserted. If this occurs then the nextIndex is now - // invalid and - // should instead point to the prevIndex which now holds an element which was "moved - // back". - nextIndex = prevIndex; - } - prevIndex = -1; - } - } - - // Entry implementation. Since this implementation uses a single Entry, we coalesce that - // into the Iterator object (potentially making loop optimization much easier). - - @Override - public int key() { - synchronized (SynchronizedIntObjectHashMap.this) { - return keys[entryIndex]; - } - } - - @Override - public V value() { - synchronized (SynchronizedIntObjectHashMap.this) { - return toExternal(values[entryIndex]); - } - } - - @Override - public void setValue(V value) { - synchronized (SynchronizedIntObjectHashMap.this) { - values[entryIndex] = toInternal(value); - } - } - } - - /** Iterator used by the {@link Map} interface. */ - private final class MapIterator implements Iterator> { - private final PrimitiveIterator iter = new PrimitiveIterator(); - - @Override - public boolean hasNext() { - synchronized (SynchronizedIntObjectHashMap.this) { - return iter.hasNext(); - } - } - - @Override - public Entry next() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (!hasNext()) { - throw new NoSuchElementException(); - } - - iter.next(); - - return new MapEntry(iter.entryIndex); - } - } - - @Override - public void remove() { - synchronized (SynchronizedIntObjectHashMap.this) { - iter.remove(); - } - } - } - - /** A single entry in the map. */ - final class MapEntry implements Entry { - private final int entryIndex; - - MapEntry(int entryIndex) { - this.entryIndex = entryIndex; - } - - @Override - public Integer getKey() { - synchronized (SynchronizedIntObjectHashMap.this) { - verifyExists(); - return keys[entryIndex]; - } - } - - @Override - public V getValue() { - synchronized (SynchronizedIntObjectHashMap.this) { - verifyExists(); - return toExternal(values[entryIndex]); - } - } - - @Override - public V setValue(V value) { - synchronized (SynchronizedIntObjectHashMap.this) { - verifyExists(); - V prevValue = toExternal(values[entryIndex]); - values[entryIndex] = toInternal(value); - return prevValue; - } - } - - private void verifyExists() { - if (values[entryIndex] == null) { - throw new IllegalStateException("The map entry has been removed"); - } - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index 4cf33fa86..12e0b60dc 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -115,7 +115,7 @@ public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { ByteBufPayload payload = RECYCLER.get(); payload.data = data; payload.metadata = metadata; - // unsure data and metadata is set before refCnt change + // ensure data and metadata is set before refCnt change payload.setRefCnt(1); return payload; } diff --git a/rsocket-core/src/test/java/io/rsocket/FrameAssert.java b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java new file mode 100644 index 000000000..b5b1e2ec9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java @@ -0,0 +1,336 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBe.shouldBe; +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; +import static org.assertj.core.error.ShouldNotHave.shouldNotHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.*; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class FrameAssert extends AbstractAssert { + public static FrameAssert assertThat(@Nullable ByteBuf frame) { + return new FrameAssert(frame); + } + + private final Failures failures = Failures.instance(); + + public FrameAssert(@Nullable ByteBuf frame) { + super(frame, FrameAssert.class); + } + + public FrameAssert hasMetadata() { + assertValid(); + + if (!FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public FrameAssert hasNoMetadata() { + assertValid(); + + if (FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public FrameAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public FrameAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public FrameAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + final FrameType frameType = FrameHeaderCodec.frameType(actual); + ByteBuf content; + if (frameType == FrameType.METADATA_PUSH) { + content = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.metadata(actual); + } else { + content = PayloadFrameCodec.metadata(actual); + } + + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public FrameAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public FrameAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content; + final FrameType frameType = FrameHeaderCodec.frameType(actual); + if (!frameType.canHaveData()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have data content but frame type %n<%s> does not support data content", + actual, frameType)); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.data(actual); + } else if (frameType == FrameType.ERROR) { + content = ErrorFrameCodec.data(actual); + } else { + content = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasFragmentsFollow() { + return hasFollows(true); + } + + public FrameAssert hasNoFragmentsFollow() { + return hasFollows(false); + } + + public FrameAssert hasFollows(boolean hasFollows) { + assertValid(); + + if (FrameHeaderCodec.hasFollows(actual) != hasFollows) { + throw failures.failure( + info, + hasFollows + ? shouldHave(actual, new Condition<>("follows fragment present")) + : shouldNotHave(actual, new Condition<>("follows fragment present"))); + } + + return this; + } + + public FrameAssert typeOf(FrameType frameType) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + if (currentFrameType != frameType) { + throw failures.failure( + info, shouldBe(currentFrameType, new Condition<>("frame of type [" + frameType + "]"))); + } + + return this; + } + + public FrameAssert hasStreamId(int streamId) { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId != streamId) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting streamId:%n<%s>%n to be equal %n<%s>", currentStreamId, streamId)); + } + + return this; + } + + public FrameAssert hasStreamIdZero() { + return hasStreamId(0); + } + + public FrameAssert hasClientSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId % 2 != 1) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting Client Side StreamId %nbut was " + + (currentStreamId == 0 ? "Stream Id 0" : "Server Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasServerSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId == 0 || currentStreamId % 2 != 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n Server Side Stream Id %nbut was %n " + + (currentStreamId == 0 ? "Stream Id 0" : "Client Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasPayloadSize(int payloadLength) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + + final int currentFrameLength = + actual.readableBytes() + - FrameHeaderCodec.size() + - (FrameHeaderCodec.hasMetadata(actual) && currentFrameType.canHaveData() ? 3 : 0) + - (currentFrameType.hasInitialRequestN() ? Integer.BYTES : 0); + if (currentFrameLength != payloadLength) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n<%s> %nframe payload size to be equal to %n<%s> %nbut was %n<%s>", + actual, payloadLength, currentFrameLength)); + } + + return this; + } + + public FrameAssert hasRequestN(int n) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + long requestN; + if (currentFrameType.hasInitialRequestN()) { + requestN = RequestStreamFrameCodec.initialRequestN(actual); + } else if (currentFrameType == FrameType.REQUEST_N) { + requestN = RequestNFrameCodec.requestN(actual); + } else { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have requestN but frame type %n<%s> does not support requestN", + actual, currentFrameType)); + } + + if ((requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : requestN) != n) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have %nrequestN(<%s>) but got %nrequestN(<%s>)", + actual, n, requestN)); + } + + return this; + } + + public FrameAssert hasPayload(Payload expectedPayload) { + assertValid(); + + List failedExpectation = new ArrayList<>(); + FrameType frameType = FrameHeaderCodec.frameType(actual); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(actual); + if (expectedPayload.hasMetadata() != hasMetadata) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), hasMetadata)); + } else if (hasMetadata) { + ByteBuf metadataContent; + if (frameType == FrameType.METADATA_PUSH) { + metadataContent = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + metadataContent = RequestStreamFrameCodec.metadata(actual); + } else { + metadataContent = PayloadFrameCodec.metadata(actual); + } + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), metadataContent)) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), metadataContent)); + } + } + + ByteBuf dataContent; + if (!frameType.canHaveData() && expectedPayload.sliceData().readableBytes() > 0) { + failedExpectation.add( + String.format( + "data(%s) but frame type %n<%s> does not support data", actual, frameType)); + } else { + if (frameType.hasInitialRequestN()) { + dataContent = RequestStreamFrameCodec.data(actual); + } else { + dataContent = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), dataContent)) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", expectedPayload.sliceData(), dataContent)); + } + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given payload but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + public void hasNoLeaks() { + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + + try { + FrameHeaderCodec.frameType(actual); + } catch (Throwable t) { + throw failures.failure( + info, shouldBe(actual, new Condition<>("a valid frame, but got exception [" + t + "]"))); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java new file mode 100755 index 000000000..847f24722 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java @@ -0,0 +1,180 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.ByteBufRepresentation; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class PayloadAssert extends AbstractAssert { + + public static PayloadAssert assertThat(@Nullable Payload payload) { + return new PayloadAssert(payload); + } + + private final Failures failures = Failures.instance(); + + public PayloadAssert(@Nullable Payload payload) { + super(payload, PayloadAssert.class); + } + + public PayloadAssert hasMetadata() { + assertValid(); + + if (!actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public PayloadAssert hasNoMetadata() { + assertValid(); + + if (actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public PayloadAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public PayloadAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public PayloadAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + ByteBuf content = actual.sliceMetadata(); + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public PayloadAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public PayloadAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public PayloadAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content = actual.sliceData(); + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public void hasNoLeaks() { + if (!(actual instanceof DefaultPayload)) { + if (actual.refCnt() == 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was already released", + actual, actual.refCnt())); + } + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + } + + public void isReleased() { + if (actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) but " + "actual was " + "%n", + actual, actual.refCnt())); + } + } + + @Override + public PayloadAssert isEqualTo(Object expected) { + if (expected instanceof Payload) { + if (expected == actual) { + return this; + } + + Payload expectedPayload = (Payload) expected; + List failedExpectation = new ArrayList<>(); + if (expectedPayload.hasMetadata() != actual.hasMetadata()) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), actual.hasMetadata())); + } else { + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), actual.sliceMetadata())) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), actual.sliceMetadata())); + } + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), actual.sliceData())) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", + expectedPayload.sliceData(), actual.sliceData())); + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given one but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + return super.isEqualTo(expected); + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java index 7398548be..e6f5722ac 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -42,7 +42,6 @@ public Statement apply(final Statement base, Description description) { @Override public void evaluate() throws Throwable { allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - connection = new TestDuplexConnection(allocator); connectSub = TestSubscriber.create(); init(); base.evaluate(); @@ -51,6 +50,13 @@ public void evaluate() throws Throwable { } protected void init() { + if (socket != null) { + socket.dispose(); + } + if (connection != null) { + connection.dispose(); + } + connection = new TestDuplexConnection(allocator); socket = newRSocket(); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java index 8d1d292c6..8be709ac0 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -27,7 +27,6 @@ import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.RSocketClient; -import io.rsocket.TestScheduler; import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; @@ -144,12 +143,26 @@ public void shouldSentFrameOnResolution( }) .then(testPublisher::complete) .then( - () -> + () -> { + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(rule.connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } else { Assertions.assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) - .matches(ReferenceCounted::release)) + .matches(ReferenceCounted::release); + } + }) .then( () -> { if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { @@ -395,10 +408,30 @@ public void shouldSupportMultiSubscriptionOnTheSameInteractionPublisher( assertSubscriber.await(Duration.ofSeconds(10)).assertComplete(); - Collection sent = rule.connection.getSent(); - Assertions.assertThat(sent) - .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) - .allMatch(ReferenceCounted::release); + if (requestType == FrameType.REQUEST_CHANNEL) { + ArrayList sent = new ArrayList<>(rule.connection.getSent()); + Assertions.assertThat(sent).hasSize(4); + for (int i = 0; i < sent.size(); i++) { + if (i % 2 == 0) { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } + } + } else { + Collection sent = rule.connection.getSent(); + Assertions.assertThat(sent) + .hasSize( + requestType == FrameType.REQUEST_FNF || requestType == FrameType.METADATA_PUSH + ? 1 + : 2) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } rule.allocator.assertHasNoLeaks(); } @@ -509,9 +542,9 @@ protected RSocketRequester newRSocket() { maxFrameLength, Integer.MAX_VALUE, Integer.MAX_VALUE, + Integer.MAX_VALUE, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); } public int getStreamIdForRequestType(FrameType expectedFrameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java new file mode 100644 index 000000000..cb5044e17 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java @@ -0,0 +1,404 @@ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class FireAndForgetRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameShouldBeSentOnSubscription(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final Payload payload = genericPayload(activeStreams.getAllocator()); + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends all fragments as a separate + * frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameFragmentsShouldBeSentOnSubscription( + Consumer monoConsumer) { + final int mtu = 64; + final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final UnboundedProcessor sender = streamManager.getSendProcessor(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + // should not add anything to map + streamManager.assertNoActiveStreams(); + stateAssert.isTerminated(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.poll(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.poll(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.poll(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.poll(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> frameSent() { + return Stream.of( + (s) -> StepVerifier.create(s).expectSubscription().expectComplete().verify(), + FireAndForgetRequesterMono::block); + } + + /** + * RefCnt validation test. Should send error if RefCnt is incorrect and frame has already been + * released Note: ONCE state should be 0 + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final UnboundedProcessor sender = streamManager.getSendProcessor(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Check that proper payload size validation is enabled so in case payload fragmentation is + * disabled we will not send anything bigger that 16MB (see specification for MAX frame size) + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final UnboundedProcessor sender = streamManager.getSendProcessor(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that frame will not be sent if we dont have availability for that. Options: 1. RSocket + * disposed / Connection Error, so all racing on existing interactions should be terminated as + * well 2. RSocket tries to use lease and end-ups with no available leases + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final UnboundedProcessor sender = streamManager.getSendProcessor(); + final Payload payload = genericPayload(allocator); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + /** Ensures single subscription happens in case of racing */ + @Test + public void shouldSubscribeExactlyOnce1() { + final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final UnboundedProcessor sender = streamManager.getSendProcessor(); + + for (int i = 1; i < 50000; i += 2) { + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AtomicReference atomicReference = new AtomicReference(); + fireAndForgetRequesterMono.subscribe(null, atomicReference::set); + Throwable throwable = atomicReference.get(); + if (throwable != null) { + throw Exceptions.propagate(throwable); + } + }, + fireAndForgetRequesterMono::block)) + .matches( + t -> { + Assertions.assertThat(t) + .hasMessageContaining("FireAndForgetMono allows only a single Subscriber"); + return true; + }); + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(i) + .hasNoLeaks(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + } + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, testRequesterResponderSupport); + + Assertions.assertThat(Scannable.from(fireAndForgetRequesterMono).name()) + .isEqualTo("source(FireAndForgetMono)"); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index 209bc3810..78f7bff66 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -24,7 +24,6 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.frame.FrameHeaderCodec; @@ -74,11 +73,11 @@ static RSocketState requester(int tickPeriod, int timeout) { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, tickPeriod, timeout, new DefaultKeepAliveHandler(connection), - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); return new RSocketState(rSocket, allocator, connection); } @@ -101,11 +100,11 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, tickPeriod, timeout, new ResumableKeepAliveHandler(resumableConnection), - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); return new ResumableRSocketState(rSocket, connection, resumableConnection, allocator); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java index 1d93d9388..707d42afe 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -1,5 +1,7 @@ package io.rsocket.core; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; + import io.rsocket.Payload; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameLengthCodec; @@ -12,26 +14,45 @@ class PayloadValidationUtilsTest { @Test void shouldBeValidFrameWithNoFragmentation() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation1() { int maxFrameLength = ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); byte[] data = - new byte[maxFrameLength - FrameLengthCodec.FRAME_LENGTH_SIZE - FrameHeaderCodec.size()]; + new byte[maxFrameLength - FRAME_LENGTH_SIZE - Integer.BYTES - FrameHeaderCodec.size()]; ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); } @Test void shouldBeInValidFrameWithNoFragmentation() { int maxFrameLength = ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); - byte[] data = - new byte[maxFrameLength - FrameLengthCodec.FRAME_LENGTH_SIZE - FrameHeaderCodec.size() + 1]; + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size() + 1]; ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); } @Test @@ -41,15 +62,18 @@ void shouldBeValidFrameWithNoFragmentation0() { byte[] metadata = new byte[maxFrameLength / 2]; byte[] data = new byte - [maxFrameLength / 2 - - FrameLengthCodec.FRAME_LENGTH_SIZE + [(maxFrameLength / 2 + 1) + - FRAME_LENGTH_SIZE - FrameHeaderCodec.size() - FrameHeaderCodec.size()]; ThreadLocalRandom.current().nextBytes(data); ThreadLocalRandom.current().nextBytes(metadata); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); } @Test @@ -62,7 +86,10 @@ void shouldBeInValidFrameWithNoFragmentation1() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); } @Test @@ -75,7 +102,10 @@ void shouldBeValidFrameWithNoFragmentation2() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); } @Test @@ -88,7 +118,10 @@ void shouldBeValidFrameWithNoFragmentation3() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(64, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); } @Test @@ -101,6 +134,9 @@ void shouldBeValidFrameWithNoFragmentation4() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(64, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 7faef600a..32bae9270 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -27,10 +27,10 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; import io.rsocket.frame.FrameHeaderCodec; @@ -104,11 +104,11 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - requesterLeaseHandler, - TestScheduler.INSTANCE); + requesterLeaseHandler); mockRSocketHandler = mock(RSocket.class); when(mockRSocketHandler.metadataPush(any())) @@ -155,7 +155,8 @@ void setUp() { payloadDecoder, responderLeaseHandler, 0, - FRAME_LENGTH_MASK); + FRAME_LENGTH_MASK, + Integer.MAX_VALUE); } @Test @@ -235,10 +236,23 @@ void requesterPresentLeaseRequestsAreAccepted( .expectComplete() .verify(Duration.ofSeconds(5)); - Assertions.assertThat(connection.getSent()) - .hasSize(1) - .first() - .matches(ReferenceCounted::release); + if (frameType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + } Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.5, offset(1e-2)); @@ -279,11 +293,24 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( // ensures availability is changed and lease is used only up on frame sending Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); - Assertions.assertThat(connection.getSent()) - .hasSize(1) - .first() - .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) - .matches(ReferenceCounted::release); + + if (interactionType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + } ByteBuf buffer2 = byteBufAllocator.buffer(); buffer2.writeCharSequence("test", CharsetUtil.UTF_8); @@ -478,6 +505,8 @@ void receiveLease() { Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); Assertions.assertThat(receivedLease.getMetadata().toString(utf8)).isEqualTo(metadataContent); + + ReferenceCountUtil.safeRelease(leaseFrame); } ByteBuf leaseFrame(int ttl, int requests, ByteBuf metadata) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java index 9ecdd13ba..34810b6bd 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -25,15 +25,12 @@ import java.util.Iterator; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.mockito.Mockito; import reactor.core.Exceptions; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; import reactor.util.retry.Retry; public class RSocketReconnectTest { @@ -42,14 +39,6 @@ public class RSocketReconnectTest { @Test public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws InterruptedException { - CountDownLatch latch = new CountDownLatch(1); - Schedulers.onScheduleHook( - "test", - r -> - () -> { - r.run(); - latch.countDown(); - }); TestClientTransport[] testClientTransport = new TestClientTransport[] {new TestClientTransport()}; Mono rSocketMono = @@ -63,10 +52,8 @@ public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws Interrupt Assertions.assertThat(rSocket1).isEqualTo(rSocket2); testClientTransport[0].testConnection().dispose(); - Assertions.assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); testClientTransport[0] = new TestClientTransport(); - System.out.println("here"); RSocket rSocket3 = rSocketMono.block(); RSocket rSocket4 = rSocketMono.block(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index 5949d9ada..64a8ec30c 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -20,8 +20,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; @@ -40,9 +40,11 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.test.util.RaceTestUtils; class RSocketRequesterSubscribersTest { @@ -71,11 +73,11 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); } @ParameterizedTest @@ -99,7 +101,8 @@ void singleSubscriber(Function> interaction) { @ParameterizedTest @MethodSource("allInteractions") - void singleSubscriberInCaseOfRacing(Function> interaction) { + void singleSubscriberInCaseOfRacing( + Function> interaction, FrameType requestType) { for (int i = 1; i < 20000; i += 2) { Flux response = Flux.from(interaction.apply(rSocketRequester)); AssertSubscriber assertSubscriberA = AssertSubscriber.create(); @@ -116,12 +119,23 @@ void singleSubscriberInCaseOfRacing(Function> interaction) Assertions.assertThat(new AssertSubscriber[] {assertSubscriberA, assertSubscriberB}) .anySatisfy(as -> as.assertError(IllegalStateException.class)); - Assertions.assertThat(connection.getSent()) - .hasSize(1) - .first() - .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) - .matches(ByteBuf::release); - + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.COMPLETE) + .matches(ByteBuf::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + } connection.clearSendReceiveBuffers(); } } @@ -141,12 +155,29 @@ static long requestFramesCount(Collection frames) { .count(); } - static Stream>> allInteractions() { + static Stream allInteractions() { return Stream.of( - rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), - rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), - rSocket -> rSocket.requestStream(DefaultPayload.create("test")), - // rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), - rSocket -> rSocket.metadataPush(DefaultPayload.create("", "test"))); + Arguments.of( + (Function>) + rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), + FrameType.REQUEST_FNF), + Arguments.of( + (Function>) + rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), + FrameType.REQUEST_RESPONSE), + Arguments.of( + (Function>) + rSocket -> rSocket.requestStream(DefaultPayload.create("test")), + FrameType.REQUEST_STREAM), + Arguments.of( + (Function>) + rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), + FrameType.REQUEST_CHANNEL), + Arguments.of( + (Function>) + rSocket -> + rSocket.metadataPush( + DefaultPayload.create(new byte[0], "test".getBytes(CharsetUtil.UTF_8))), + FrameType.METADATA_PUSH)); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 45770d375..56e3c5633 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -37,7 +37,6 @@ import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.CustomRSocketException; import io.rsocket.exceptions.RejectedSetupException; @@ -355,7 +354,8 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen t -> Assertions.assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .verify(); rule.assertHasNoLeaks(); }); @@ -370,27 +370,33 @@ static Stream>> prepareCalls() { RSocket::metadataPush); } - @Test + @ParameterizedTest + @ValueSource(ints = {128, 256, FrameLengthCodec.FRAME_LENGTH_MASK}) public void - shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { - byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; - byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; ThreadLocalRandom.current().nextBytes(metadata); ThreadLocalRandom.current().nextBytes(data); StepVerifier.create( rule.socket.requestChannel( - Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata)))) + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata))), + 0) .expectSubscription() + .thenRequest(2) .then( - () -> - rule.connection.addToReceivedBuffer( - RequestNFrameCodec.encode( - rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2))) + () -> { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2)); + }) .expectErrorSatisfies( t -> Assertions.assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .verify(); Assertions.assertThat(rule.connection.getSent()) // expect to be sent RequestChannelFrame @@ -537,9 +543,10 @@ private static Stream racingCases() { (as, rule) -> { RaceTestUtils.race(() -> as.request(1), as::cancel); // ensures proper frames order - if (rule.connection.getSent().size() > 0) { - // - // Assertions.assertThat(rule.connection.getSent()).hasSize(2); + int size = rule.connection.getSent().size(); + if (size > 0) { + + Assertions.assertThat(size).isLessThanOrEqualTo(3).isGreaterThanOrEqualTo(2); Assertions.assertThat(rule.connection.getSent()) .element(0) .matches( @@ -549,16 +556,43 @@ private static Stream racingCases() { + "} but was {" + frameType(rule.connection.getSent().stream().findFirst().get()) + "}"); - Assertions.assertThat(rule.connection.getSent()) - .element(1) - .matches( - bb -> frameType(bb) == CANCEL, - "Expected first frame matches {" - + CANCEL - + "} but was {" - + frameType( - rule.connection.getSent().stream().skip(1).findFirst().get()) - + "}"); + if (size == 2) { + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected second frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } else { + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == COMPLETE || frameType(bb) == CANCEL, + "Expected second frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(2) + .matches( + bb -> frameType(bb) == CANCEL || frameType(bb) == COMPLETE, + "Expected third frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(2).findFirst().get()) + + "}"); + } } }), Arguments.of( @@ -813,25 +847,36 @@ static Stream encodeDecodePayloadCases() { @ParameterizedTest @MethodSource("refCntCases") public void ensureSendsErrorOnIllegalRefCntPayload( - BiFunction> sourceProducer) { + BiFunction> sourceProducer) { Payload invalidPayload = ByteBufPayload.create("test", "test"); invalidPayload.release(); - Publisher source = sourceProducer.apply(invalidPayload, rule.socket); + Publisher source = sourceProducer.apply(invalidPayload, rule); - StepVerifier.create(source, 0) + StepVerifier.create(source, 1) .expectError(IllegalReferenceCountException.class) - .verify(Duration.ofMillis(100)); + .verify(Duration.ofMillis(1000)); } - private static Stream>> refCntCases() { + private static Stream>> refCntCases() { return Stream.of( - (p, r) -> r.fireAndForget(p), - (p, r) -> r.requestResponse(p), - (p, r) -> r.requestStream(p), - (p, r) -> r.requestChannel(Mono.just(p)), - (p, r) -> - r.requestChannel(Flux.just(EmptyPayload.INSTANCE, p).doOnSubscribe(s -> s.request(1)))); + (p, clientSocketRule) -> clientSocketRule.socket.fireAndForget(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestResponse(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestStream(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestChannel(Mono.just(p)), + (p, clientSocketRule) -> { + Flux.from(clientSocketRule.connection.getSentAsPublisher()) + .filter(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_CHANNEL) + .subscribe( + bb -> { + clientSocketRule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + clientSocketRule.allocator, FrameHeaderCodec.streamId(bb), 1)); + bb.release(); + }); + + return clientSocketRule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE, p)); + }); } @Test @@ -904,7 +949,7 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( assertSubscriber2.request(1); Assertions.assertThat(rule.connection.getSent()) - .hasSize(1) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) .first() .matches(bb -> frameType(bb) == frameType) .matches( @@ -931,11 +976,23 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( }) .matches(ReferenceCounted::release); + if (frameType == REQUEST_CHANNEL) { + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } + rule.connection.clearSendReceiveBuffers(); assertSubscriber1.request(1); Assertions.assertThat(rule.connection.getSent()) - .hasSize(1) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) .first() .matches(bb -> frameType(bb) == frameType) .matches( @@ -961,6 +1018,18 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( return false; }) .matches(ReferenceCounted::release); + + if (frameType == REQUEST_CHANNEL) { + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } } private static Stream requestNInteractions() { @@ -981,6 +1050,7 @@ private static Stream requestNInteractions() { @ParameterizedTest @MethodSource("streamRacingCases") + @Disabled("Connection should take care of ordering if such is necessary") public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( BiFunction> interaction1, BiFunction> interaction2, @@ -1090,7 +1160,7 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( Schedulers.parallel()); assertSubscriber1.await().assertTerminated(); - if (interactionType1 != REQUEST_FNF) { + if (interactionType1 != REQUEST_FNF && interactionType1 != METADATA_PUSH) { assertSubscriber1.assertError(ClosedChannelException.class); } else { try { @@ -1101,7 +1171,7 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( } } assertSubscriber2.await().assertTerminated(); - if (interactionType2 != REQUEST_FNF) { + if (interactionType2 != REQUEST_FNF && interactionType2 != METADATA_PUSH) { assertSubscriber2.assertError(ClosedChannelException.class); } else { try { @@ -1121,6 +1191,7 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( } @Test + @Disabled("Reactor 3.4.0 should fix that. No need to do anything on our side") // see https://github.com/rsocket/rsocket-java/issues/858 public void testWorkaround858() { ByteBuf buffer = rule.alloc().buffer(); @@ -1153,9 +1224,9 @@ protected RSocketRequester newRSocket() { maxFrameLength, Integer.MAX_VALUE, Integer.MAX_VALUE, + Integer.MAX_VALUE, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); } public int getStreamIdForRequestType(FrameType expectedFrameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 0d0fbd8c0..8b6518478 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -237,7 +237,10 @@ protected void hookOnSubscribe(Subscription subscription) { .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.ERROR) - .matches(bb -> ErrorFrameCodec.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)) + .matches( + bb -> + ErrorFrameCodec.dataUtf8(bb) + .contains(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .matches(ReferenceCounted::release); assertThat("Subscription not cancelled.", cancelled.get(), is(true)); @@ -471,7 +474,7 @@ public Flux requestChannel(Publisher payloads) { assertSubscriber .assertTerminated() .assertError(CancellationException.class) - .assertErrorMessage("Disposed"); + .assertErrorMessage("Outbound has terminated with an error"); Assertions.assertThat(assertSubscriber.values()) .allMatch( msg -> { @@ -783,6 +786,7 @@ private static Stream refCntCases() { } @Test + @Disabled("Reactor 3.4.0 should fix that. No need to do anything on our side") // see https://github.com/rsocket/rsocket-java/issues/858 public void testWorkaround858() { ByteBuf buffer = rule.alloc().buffer(); @@ -859,7 +863,8 @@ protected RSocketResponder newRSocket() { PayloadDecoder.ZERO_COPY, ResponderLeaseHandler.None, 0, - maxFrameLength); + maxFrameLength, + Integer.MAX_VALUE); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index 1e7bb337f..38745327e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -22,7 +22,6 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.CustomRSocketException; @@ -163,7 +162,7 @@ public void testStream() throws Exception { StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); } - @Test(timeout = 2000) + @Test(timeout = 200000) public void testChannel() throws Exception { Flux requests = Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); @@ -569,7 +568,8 @@ public Flux requestChannel(Publisher payloads) { PayloadDecoder.DEFAULT, ResponderLeaseHandler.None, 0, - FRAME_LENGTH_MASK); + FRAME_LENGTH_MASK, + Integer.MAX_VALUE); crs = new RSocketRequester( @@ -578,11 +578,11 @@ public Flux requestChannel(Publisher payloads) { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); } public void setRequestAcceptor(RSocket requestAcceptor) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java new file mode 100644 index 000000000..4fc06fdc2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java @@ -0,0 +1,756 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(10); + + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + stateAssert.hasSubscribedFlag().hasRequestN(10).hasNoFirstFrameSentFlag(); + + publisher.assertMaxRequested(1).next(payload); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(10).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(10) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.poll(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(11).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.poll(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelRequesterFlux.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldErrorWithoutInitializingRemoteStreamIfSourceIsEmpty(boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.complete(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Empty Source"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldPropagateErrorWithoutInitializingRemoteStreamIfTheFirstSignalIsError( + boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.error(new RuntimeException("test")); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(RuntimeException.class) + .assertErrorMessage("test"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.poll()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload2).hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload3).hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + requestChannelRequesterFlux.handleCancel(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasOutboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleComplete(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasInboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleCancel(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void errorShouldTerminateExecution(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.poll()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload2).hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload3).hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + publisher.error(new ApplicationErrorException("test")); + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.ERROR).hasData("test").hasNoLeaks(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleError(new ApplicationErrorException("test")); + publisher.assertWasCancelled(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.poll()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelRequesterFlux.handlePayload(responsePayload1); + requestChannelRequesterFlux.handlePayload(responsePayload2); + requestChannelRequesterFlux.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelRequesterFlux.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelRequesterFlux.handleComplete(); + } else { + requestChannelRequesterFlux.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.poll(); + if (errorFrameOrEmpty != null) { + if (outboundTerminationMode.equals("onError")) { + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData("outboundException") + .hasNoLeaks(); + } else { + FrameAssert.assertThat(errorFrameOrEmpty).typeOf(FrameType.CANCEL).hasNoLeaks(); + } + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + PayloadAssert.assertThat(signal.get()) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + Assertions.assertThat(throwable).isEqualTo(inboundException); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + Assertions.assertThat(throwable) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else if (inboundTerminationMode.equals("complete")) { + if (signal.isOnComplete()) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).isEmpty(); + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the error signal[%s] is the last signal, but the last was %s", + j, values.size() - 1) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"complete", "cancel"}) + public void shouldRemoveItselfFromActiveStreamsWhenInboundAndOutboundAreTerminated( + String outboundTerminationMode) { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.poll()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("cancel")) { + requestChannelRequesterFlux.handleCancel(); + } else { + publisher.complete(); + } + }, + requestChannelRequesterFlux::handleComplete); + + ByteBuf completeFrameOrNull = sender.poll(); + if (completeFrameOrNull != null) { + FrameAssert.assertThat(completeFrameOrNull) + .hasStreamId(1) + .typeOf(FrameType.COMPLETE) + .hasNoLeaks(); + } + + assertSubscriber.assertTerminated().assertComplete(); + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java new file mode 100644 index 000000000..b1c1e8cf9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java @@ -0,0 +1,688 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.*; +import static reactor.test.publisher.TestPublisher.Violation.CLEANUP_ON_TERMINATE; +import static reactor.test.publisher.TestPublisher.Violation.DEFER_CANCELLATION; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Exceptions; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelResponderSubscriberTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound", "inboundCancel"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.poll()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + publisher.next(TestRequesterResponderSupport.genericPayload(allocator)); + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.poll(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelResponderSubscriber.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelResponderSubscriber.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("inboundCancel")) { + assertSubscriber.cancel(); + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }); + + FrameAssert.assertThat(sender.poll()).typeOf(CANCEL).hasStreamId(1).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasStreamId(1).hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCompleteWithSubscription() { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> + requestChannelResponderSubscriber + .doOnNext(__ -> assertSubscriber.request(1)) + .subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleComplete()); + + stateAssert + .hasSubscribedFlag() + .hasInboundTerminated() + .hasFirstFrameSentFlag() + .hasRequestNBetween(1, 2); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertTerminated() + .assertComplete(); + + publisher.complete(); + + if (sender.size() > 1) { + FrameAssert.assertThat(sender.poll()) + .hasStreamId(1) + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasNoLeaks(); + } + FrameAssert.assertThat(sender.poll()).hasStreamId(1).typeOf(COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { + ApplicationErrorException applicationErrorException = new ApplicationErrorException("test"); + + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleError(applicationErrorException)); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(applicationErrorException.getClass()) + .assertErrorMessage("test"); + + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCancelWithSubscription() { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleCancel()); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Inbound has been canceled"); + + allocator.assertHasNoLeaks(); + } + } + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, requestPayload, activeStreams); + + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + final AssertSubscriber> assertSubscriber = + requestChannelResponderSubscriber + .materialize() + .subscribeWith(AssertSubscriber.create(0)); + + assertSubscriber.request(Integer.MAX_VALUE); + + FrameAssert.assertThat(sender.poll()) + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelResponderSubscriber.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelResponderSubscriber.handlePayload(responsePayload1); + requestChannelResponderSubscriber.handlePayload(responsePayload2); + requestChannelResponderSubscriber.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelResponderSubscriber.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelResponderSubscriber.handleComplete(); + } else { + requestChannelResponderSubscriber.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.poll(); + if (errorFrameOrEmpty != null) { + String message; + if (outboundTerminationMode.equals("onError")) { + message = outboundException.getMessage(); + } else if (outboundTerminationMode.equals("sizeError")) { + message = String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK); + } else { + message = "Failed to validate payload. Cause:refCnt: 0"; + } + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData(message) + .hasNoLeaks(); + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + Payload payload = signal.get(); + if (j == 0) { + Assertions.assertThat(payload).isEqualTo(requestPayload); + } + + PayloadAssert.assertThat(payload) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + inboundException.getMessage(), + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors) + .hasSize(1) + .first() + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + } + } + } else if (inboundTerminationMode.equals("complete")) { + Assertions.assertThat(droppedErrors).isEmpty(); + if (signal.isOnError()) { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf(CancellationException.class) + .matches( + t -> + t.getMessage().equals("Inbound has been canceled") + || t.getMessage().equals("Outbound has terminated with an error")); + } + } else { + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + "Inbound has been canceled", + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + Assertions.assertThat(throwable).isExactlyInstanceOf(CancellationException.class); + } + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the %s signal[%s] is the last signal, but the last was %s", + signal, j, values.get(values.size() - 1)) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"onError", "sizeError", "refCntError", "cancel"}) + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(String terminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(1); + + Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final RequestChannelResponderSubscriber requestOperator = + new RequestChannelResponderSubscriber(1, Long.MAX_VALUE, firstPayload, activeStreams); + + publisher.subscribe(requestOperator); + requestOperator.subscribe(assertSubscriber); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, responsePayload); + + Payload releasedPayload1 = ByteBufPayload.create(new byte[0]); + Payload releasedPayload2 = ByteBufPayload.create(new byte[0]); + releasedPayload1.release(); + releasedPayload2.release(); + + RaceTestUtils.race( + () -> { + switch (terminationMode) { + case "onError": + publisher.error(outboundException); + break; + case "sizeError": + publisher.next(oversizePayload); + break; + case "refCntError": + publisher.next(releasedPayload1); + break; + case "cancel": + default: + assertSubscriber.cancel(); + } + }, + () -> { + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + requestOperator.handleNext(frame, lastFragmentId != j, false); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + + PayloadAssert.assertThat(values.get(0)).isEqualTo(firstPayload).hasNoLeaks(); + + if (values.size() > 1) { + Payload payload = values.get(1); + PayloadAssert.assertThat(payload).isEqualTo(responsePayload).hasNoLeaks(); + } + + if (!sender.isEmpty()) { + if (terminationMode.equals("cancel")) { + assertSubscriber.assertNotTerminated(); + } else { + assertSubscriber.assertTerminated().assertError(); + } + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .typeOf(terminationMode.equals("cancel") ? CANCEL : ERROR) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + + PayloadAssert.assertThat(responsePayload).hasNoLeaks(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java new file mode 100644 index 000000000..86babe671 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java @@ -0,0 +1,695 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestResponseRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + * + */ + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends frame on the first request 3) + * terminates up on receiving the first signal (terminates on first next | error | next over + * reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(RequestResponseRequesterMono.STATE, requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + PayloadAssert.assertThat(payload).isReleased(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + stateAssert.isTerminated(); + + if (!sender.isEmpty()) { + ByteBuf cancelFrame = sender.poll(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnSubscriptionResponses() { + return Stream.of( + // next case + (rrm, sv) -> + sv.then(() -> rrm.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .expectComplete(), + // complete case + (rrm, sv) -> sv.then(rrm::handleComplete).expectComplete(), + // error case + (rrm, sv) -> + sv.then(() -> rrm.handleError(new ApplicationErrorException("test"))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + // fragmentation case + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then(stateAssert::isTerminated) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + p.release(); + }) + .then(payload::release) + .expectComplete(); + }, + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then( + () -> { + rrm.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends fragments frames on the first + * request 3) terminates up on receiving the first signal (terminates on first next | error | next + * over reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameFragmentsShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.poll(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.poll(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.poll(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.poll(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.poll()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender).isEmpty(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General StateMachine transition test. Ensures that no fragment is sent if mono was cancelled + * before any requests + */ + @Test + public void shouldBeNoOpsOnCancel() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(() -> stateAssert.hasSubscribedFlagOnly()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenCancel() + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload is an invalid one. + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestResponseRequesterMono); + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender).isEmpty(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is disabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = ByteBufPayload.create(""); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is enabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrates + * to the terminated in case the given payload is too big with disabled fragmentation + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender).isEmpty(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that error check happens exactly before frame sent. This cases ensures that in case no + * lease / other external errors appeared, the local subscriber received the same one. No frames + * should be sent + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then(() -> StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestResponseRequesterMono).name()) + .isEqualTo("source(RequestResponseMono)"); + requestResponseRequesterMono.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java new file mode 100644 index 000000000..9791b0786 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java @@ -0,0 +1,1146 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestStreamRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + + /** + * State Machine check. Ensure migration from + * + *

+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * REQUESTED(0) -> REQUESTED(1) -> REQUESTED(0)
+   * REQUESTED(0) -> REQUESTED(MAX)
+   * REQUESTED(MAX) -> REQUESTED(MAX) && REASSEMBLY (extra flag enabled which indicates
+   * reassembly)
+   * REQUESTED(MAX) && REASSEMBLY -> TERMINATED
+   * 
+ */ + @Test + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.poll(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.poll(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + ByteBuf firstFragment = fragments.remove(0); + requestStreamRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollowing = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestStreamRequesterFlux.handleNext(followingFragment, hasFollowing, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + Payload finalRandomPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(finalRandomPayload); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(finalRandomPayload).hasNoLeaks()) + .assertComplete(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(MAX)
+   * REQUESTED(MAX) -> TERMINATED
+   * 
+ */ + @Test + public void requestNFrameShouldBeSentExactlyOnceIfItIsMaxAllowed() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Long.MAX_VALUE / 2 + 1); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + requestStreamRequesterFlux.handlePayload(EmptyPayload.INSTANCE); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber.assertValues(EmptyPayload.INSTANCE).assertComplete(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * 
+ * + * And then for the following cases: + * + *
+   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
+   * affect state anyhow and should not sent any extra frames)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [3]: REQUESTED(0) -> REASSEMBLY
+   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [4]: REQUESTED(0) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
+   * 
+ */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.poll(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.poll(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + final ByteBuf cancelFrame = sender.poll(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnFirstRequestResponses() { + return Stream.of( + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handleError(new ApplicationErrorException("test"))) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .thenRequest(1L) + .thenRequest(1L) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + final Payload payload2 = ByteBufPayload.create(data, metadata); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .then(payload::release) + .then(() -> rsf.handlePayload(payload2)) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(); + }, + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload0 = ByteBufPayload.create(data, metadata); + final Payload payload = ByteBufPayload.create(data, metadata); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then(() -> rsf.handlePayload(payload0)) + .assertNext(p -> PayloadAssert.assertThat(p).isEqualTo(payload0).hasNoLeaks()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + // state machine check + StateAssert.assertThat(rsf).isTerminated(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * State Machine check with fragmentation of the first payload. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * 
+ * + * And then for the following cases: + * + *
+   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
+   * affect state anyhow and should not sent any extra frames)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [3]: REQUESTED(0) -> REASSEMBLY
+   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [4]: REQUESTED(0) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
+   * 
+ */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameFragmentsShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.poll(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N) + // InitialRequestN size + .hasMetadata(Arrays.copyOf(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.poll(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA) + .hasMetadata( + Arrays.copyOfRange(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N, 65)) + .hasData(Arrays.copyOf(data, 35)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.poll(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 35, 35 + 55)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.poll(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(39) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 90, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.poll(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.poll()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender).isEmpty(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Case which ensures that if Payload has incorrect refCnt, the flux ends up with an appropriate + * error + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender).isEmpty(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will exponse + * the error immediatelly and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = ByteBufPayload.create(""); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender).isEmpty(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will expose + * the error immediately and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender).isEmpty(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if the given payload is exits 16mb size with disabled fragmentation, than the + * appropriate validation happens and a corresponding error will be propagagted to the subscriber + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender).isEmpty(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that the interactions check and respect rsocket availability (such as leasing) and + * propagate an error to the final subscriber. No frame should be sent. Check should happens + * exactly on the first request. + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestStreamRequesterFlux).name()) + .isEqualTo("source(RequestStreamFlux)"); + requestStreamRequesterFlux.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java new file mode 100644 index 000000000..8aee36467 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java @@ -0,0 +1,702 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +@SuppressWarnings("ALL") +public class RequesterOperatorsRacingTest { + + interface Scenario { + FrameType requestType(); + + Publisher requestOperator( + Supplier payloadsSupplier, RequesterResponderSupport requesterResponderSupport); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return METADATA_PUSH; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new MetadataPushRequesterMono(payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return MetadataPushRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_FNF; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new FireAndForgetRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return FireAndForgetRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestResponseRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestStreamRequesterFlux( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestStreamRequesterFlux.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestChannelRequesterFlux( + Flux.generate(s -> s.next(payloadsSupplier.get())), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestChannelRequesterFlux.class.getSimpleName(); + } + }); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + /** Ensures single subscription happens in case of racing */ + @ParameterizedTest(name = "Should subscribe exactly once to {0}") + @MethodSource("scenarios") + public void shouldSubscribeExactlyOnce(Scenario scenario) { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport requesterResponderSupport = + TestRequesterResponderSupport.client(); + final Supplier payloadSupplier = + () -> + TestRequesterResponderSupport.genericPayload( + requesterResponderSupport.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, requesterResponderSupport); + + StepVerifier stepVerifier = + StepVerifier.create(requesterResponderSupport.getSendProcessor()) + .assertNext( + frame -> { + FrameAssert frameAssert = + FrameAssert.assertThat(frame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()); + if (scenario.requestType() == METADATA_PUSH) { + frameAssert + .hasStreamIdZero() + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT); + } else { + frameAssert + .hasClientSideStreamId() + .hasStreamId(1) + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length + + TestRequesterResponderSupport.DATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT); + } + frameAssert.hasNoLeaks(); + + if (requestOperator instanceof FrameHandler) { + ((FrameHandler) requestOperator).handleComplete(); + } + }) + .thenCancel() + .verifyLater(); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + }, + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + })) + .matches( + t -> { + Assertions.assertThat(t).hasMessageContaining("allows only a single Subscriber"); + return true; + }); + + stepVerifier.verify(Duration.ofSeconds(1)); + Assertions.assertThat(requesterResponderSupport.getSendProcessor().isEmpty()).isTrue(); + requesterResponderSupport.getAllocator().assertHasNoLeaks(); + } + } + + /** Ensures single frame is sent only once racing between requests */ + @ParameterizedTest(name = "{0} should sent requestFrame exactly once if request(n) is racing") + @MethodSource("scenarios") + public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + requestOperator.subscribe(assertSubscriber); + + RaceTestUtils.race(() -> assertSubscriber.request(1), () -> assertSubscriber.request(1)); + + final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + + if (scenario.requestType().hasInitialRequestN()) { + if (RequestStreamFrameCodec.initialRequestN(sentFrame) == 1) { + FrameAssert.assertThat(activeStreams.getSendProcessor().poll()) + .isNotNull() + .hasStreamId(1) + .hasRequestN(1) + .typeOf(REQUEST_N) + .hasNoLeaks(); + } else { + Assertions.assertThat(RequestStreamFrameCodec.initialRequestN(sentFrame)).isEqualTo(2); + } + } + + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + ((RequesterFrameHandler) requestOperator).handlePayload(response); + ((RequesterFrameHandler) requestOperator).handleComplete(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + ((CoreSubscriber) requestOperator).onComplete(); + FrameAssert.assertThat(activeStreams.getSendProcessor().poll()) + .typeOf(COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + } + + assertSubscriber + .assertTerminated() + .assertValuesWith( + p -> { + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that no ByteBuf is leaked if reassembly is starting and cancel is happening at the same + * time + */ + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(1); + + requestOperator.subscribe(assertSubscriber); + + final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = + TestRequesterResponderSupport.randomPayload(activeStreams.getAllocator()); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments( + activeStreams.getAllocator(), mtu, responsePayload); + RaceTestUtils.race( + assertSubscriber::cancel, + () -> { + FrameHandler frameHandler = (FrameHandler) requestOperator; + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + frameHandler.handleNext(frame, lastFragmentId != j, lastFragmentId == j); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + if (!values.isEmpty()) { + Assertions.assertThat(values) + .hasSize(1) + .first() + .matches( + p -> { + Assertions.assertThat(p.sliceData()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceData())); + Assertions.assertThat(p.hasMetadata()).isEqualTo(responsePayload.hasMetadata()); + Assertions.assertThat(p.sliceMetadata()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceMetadata())); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + return true; + }); + } + + if (!activeStreams.getSendProcessor().isEmpty()) { + if (scenario.requestType() != REQUEST_CHANNEL) { + assertSubscriber.assertNotTerminated(); + } + + final ByteBuf cancellationFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + + Assertions.assertThat(responsePayload.release()).isTrue(); + Assertions.assertThat(responsePayload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case of racing between next element and cancel we will not have any memory + * leaks + */ + @Test + public void shouldHaveNoLeaksOnNextAndCancelRacing() { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final Payload payload = + TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + StepVerifier.create(requestResponseRequesterMono.doOnNext(Payload::release)) + .expectSubscription() + .expectComplete() + .verifyLater(); + + final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + requestResponseRequesterMono::cancel, + () -> requestResponseRequesterMono.handlePayload(response)); + + Assertions.assertThat(payload.refCnt()).isZero(); + Assertions.assertThat(response.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + final boolean isEmpty = activeStreams.getSendProcessor().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + + StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case we have element reassembling and then it happens the remote sends + * (errorFrame) and downstream subscriber sends cancel() and we have racing between onError and + * cancel we will not have any memory leaks + */ + @ParameterizedTest + @ValueSource(booleans = {false, true}) + public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(boolean withReassembly) { + final ArrayList droppedErrors = new ArrayList<>(); + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final Payload payload = + TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + final AssertSubscriber assertSubscriber = + requestResponseRequesterMono.subscribeWith(AssertSubscriber.create(0)); + + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (withReassembly) { + final ByteBuf fragmentBuf = + activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); + requestResponseRequesterMono.handleNext(fragmentBuf, true, false); + // mimic frameHandler behaviour + fragmentBuf.release(); + } + + final RuntimeException testException = new RuntimeException("test"); + RaceTestUtils.race( + requestResponseRequesterMono::cancel, + () -> requestResponseRequesterMono.handleError(testException)); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + + final boolean isEmpty = activeStreams.getSendProcessor().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(droppedErrors).containsExactly(testException); + } else { + assertSubscriber.assertTerminated().assertErrorMessage("test"); + } + Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + + stateAssert.isTerminated(); + droppedErrors.clear(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + /** + * Ensures that in case of racing between first request and cancel does not going to introduce + * leaks.
+ *
+ * + *

Please note, first request may or may not happen so in case it happened before cancellation + * signal we have to observe + * + *

    + *
  • RequestResponseFrame + *
  • CancellationFrame + *
+ * + *

exactly in that order + * + *

Ensures full serialization of outgoing signal (frames) + */ + @Test + public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final Payload payload = + TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = + requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + + RaceTestUtils.race(() -> assertSubscriber.cancel(), () -> assertSubscriber.request(1)); + + if (!activeStreams.getSendProcessor().isEmpty()) { + final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf cancelFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + + Assertions.assertThat(payload.refCnt()).isZero(); + + StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + + requestResponseRequesterMono.handlePayload(response); + Assertions.assertThat(response.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** Ensures that CancelFrame is sent exactly once in case of racing between cancel() methods */ + @Test + public void shouldSentCancelFrameExactlyOnce() { + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final Payload payload = + TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = + requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + + assertSubscriber.request(1); + + final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + requestResponseRequesterMono::cancel, requestResponseRequesterMono::cancel); + + final ByteBuf cancelFrame = activeStreams.getSendProcessor().poll(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + + requestResponseRequesterMono.handlePayload(response); + Assertions.assertThat(response.refCnt()).isZero(); + + requestResponseRequesterMono.handleComplete(); + assertSubscriber.assertNotTerminated(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java new file mode 100755 index 000000000..2872d8d78 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java @@ -0,0 +1,413 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.NEXT; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.test.publisher.TestPublisher; + +public class ResponderOperatorsCommonTest { + + interface Scenario { + FrameType requestType(); + + int maxElements(); + + ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler); + + ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public int maxElements() { + return 1; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber( + streamId, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + return handler.requestResponse(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + return handler.requestStream(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestStreamResponderSubscriber.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber responderSubscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstPayload, streamManager); + streamManager.activeStreams.put(streamId, responderSubscriber); + return handler.requestChannel(responderSubscriber).subscribeWith(responderSubscriber); + } + + @Override + public String toString() { + return RequestChannelResponderSubscriber.class.getSimpleName(); + } + }); + } + + static class TestHandler implements RSocket { + + final TestPublisher producer; + final AssertSubscriber consumer; + + TestHandler(TestPublisher producer, AssertSubscriber consumer) { + this.producer = producer; + this.consumer = consumer; + } + + @Override + public Mono fireAndForget(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono().then(); + } + + @Override + public Mono requestResponse(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(consumer); + return producer.flux(); + } + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, + TestRequesterResponderSupport.genericPayload(allocator), + testRequesterResponderSupport, + testHandler); + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.poll()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasStreamId(1).hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.poll()) + .typeOf(FrameType.REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith(p -> PayloadAssert.assertThat(p).hasNoLeaks()); + + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleFragmentedRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + responderFrameHandler.handleNext(fragment, hasFollows, !hasFollows); + fragment.release(); + } + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.poll()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasStreamId(1).hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.poll()).isNull(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith( + p -> PayloadAssert.assertThat(p).hasData(firstPayload.sliceData()).hasNoLeaks()) + .assertComplete(); + + testRequesterResponderSupport.assertNoActiveStreams(); + + firstPayload.release(); + + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleInterruptedFragmentation(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + firstPayload.release(); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + if (hasFollows) { + responderFrameHandler.handleNext(fragment, true, false); + } else { + responderFrameHandler.handleCancel(); + } + fragment.release(); + } + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertNoActiveStreams(); + + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index fe53b7df4..a64bf9b81 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -59,11 +59,11 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); String errorMsg = "error"; @@ -96,11 +96,11 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + RequesterLeaseHandler.None); conn.addToReceivedBuffer( ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java new file mode 100644 index 000000000..88e0dc8e2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java @@ -0,0 +1,98 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%s)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_HAVE_FLAG = "Expected state\n\t%s\nto have\n\t%s\nbut had\n\t[%s]"; + + private ShouldHaveFlag(long currentState, String expectedFlag, String actualFlags) { + super(SHOULD_HAVE_FLAG, toBinaryString(currentState), expectedFlag, actualFlags); + } + + static ErrorMessageFactory shouldHaveFlag(long currentState, long expectedFlag) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag(currentState, FLAGS_NAMES.get(expectedFlag), stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestN(long currentState, long expectedRequestN) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + expectedRequestN == Integer.MAX_VALUE ? "MAX" : expectedRequestN), + stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestNBetween( + long currentState, long expectedRequestNMin, long expectedRequestNMax) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + (expectedRequestNMin == Integer.MAX_VALUE ? "MAX" : expectedRequestNMin) + + " - " + + (expectedRequestNMax == Integer.MAX_VALUE ? "MAX" : expectedRequestNMax)), + stateAsString); + } + + private static String extractStateAsString(long currentState) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append( + String.format( + FLAGS_NAMES.get(REQUEST_MASK), requestN >= Integer.MAX_VALUE ? "MAX" : requestN)); + } + return stringBuilder.toString(); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java new file mode 100644 index 000000000..e281e548c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java @@ -0,0 +1,73 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldNotHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%n)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_NOT_HAVE_FLAG = + "Expected state\n\t%s\nto not have\n\t%s\nbut had\n\t[%s]"; + + private ShouldNotHaveFlag(long currentState, long expectedFlag, String actualFlags) { + super( + SHOULD_NOT_HAVE_FLAG, + toBinaryString(currentState), + FLAGS_NAMES.get(expectedFlag), + actualFlags); + } + + static ErrorMessageFactory shouldNotHaveFlag(long currentState, long expectedFlag) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(String.format(FLAGS_NAMES.get(REQUEST_MASK), requestN)); + } + return new ShouldNotHaveFlag(currentState, expectedFlag, stringBuilder.toString()); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java new file mode 100644 index 000000000..64253984b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java @@ -0,0 +1,161 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.ShouldHaveFlag.*; +import static io.rsocket.core.ShouldNotHaveFlag.shouldNotHaveFlag; +import static io.rsocket.core.StateUtils.*; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.internal.Failures; + +public class StateAssert extends AbstractAssert, AtomicLongFieldUpdater> { + + public static StateAssert assertThat(AtomicLongFieldUpdater updater, T instance) { + return new StateAssert<>(updater, instance); + } + + public static StateAssert assertThat( + FireAndForgetRequesterMono instance) { + return new StateAssert<>(FireAndForgetRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestResponseRequesterMono instance) { + return new StateAssert<>(RequestResponseRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestStreamRequesterFlux instance) { + return new StateAssert<>(RequestStreamRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelRequesterFlux instance) { + return new StateAssert<>(RequestChannelRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelResponderSubscriber instance) { + return new StateAssert<>(RequestChannelResponderSubscriber.STATE, instance); + } + + private final Failures failures = Failures.instance(); + private final T instance; + + public StateAssert(AtomicLongFieldUpdater updater, T instance) { + super(updater, StateAssert.class); + this.instance = instance; + } + + public StateAssert isUnsubscribed() { + long currentState = actual.get(instance); + if (isSubscribed(currentState) || StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, UNSUBSCRIBED_STATE)); + } + return this; + } + + public StateAssert hasSubscribedFlagOnly() { + long currentState = actual.get(instance); + if (currentState != SUBSCRIBED_FLAG) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasSubscribedFlag() { + long currentState = actual.get(instance); + if (!isSubscribed(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasRequestN(long n) { + long currentState = actual.get(instance); + if (extractRequestN(currentState) != n) { + throw failures.failure(info, shouldHaveRequestN(currentState, n)); + } + return this; + } + + public StateAssert hasRequestNBetween(long min, long max) { + long currentState = actual.get(instance); + final long requestN = extractRequestN(currentState); + if (requestN < min || requestN > max) { + throw failures.failure(info, shouldHaveRequestNBetween(currentState, min, max)); + } + return this; + } + + public StateAssert hasFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (!isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasNoFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasReassemblingFlag() { + long currentState = actual.get(instance); + if (!isReassembling(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasNoReassemblingFlag() { + long currentState = actual.get(instance); + if (isReassembling(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasInboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isInboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, INBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert hasOutboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isOutboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, OUTBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert isTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, TERMINATED_STATE)); + } + return this; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java index 00248b6d8..98fde97f7 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -20,14 +20,14 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; -import io.rsocket.internal.SynchronizedIntObjectHashMap; import org.junit.Test; public class StreamIdSupplierTest { @Test public void testClientSequence() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); assertEquals(1, s.nextStreamId(map)); assertEquals(3, s.nextStreamId(map)); @@ -36,7 +36,7 @@ public void testClientSequence() { @Test public void testServerSequence() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.serverSupplier(); assertEquals(2, s.nextStreamId(map)); assertEquals(4, s.nextStreamId(map)); @@ -45,7 +45,7 @@ public void testServerSequence() { @Test public void testClientIsValid() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); assertFalse(s.isBeforeOrCurrent(1)); @@ -68,7 +68,7 @@ public void testClientIsValid() { @Test public void testServerIsValid() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.serverSupplier(); assertFalse(s.isBeforeOrCurrent(2)); @@ -91,7 +91,7 @@ public void testServerIsValid() { @Test public void testWrap() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = new StreamIdSupplier(Integer.MAX_VALUE - 3); assertEquals(2147483646, s.nextStreamId(map)); @@ -107,7 +107,7 @@ public void testWrap() { @Test public void testSkipFound() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); map.put(5, new Object()); map.put(9, new Object()); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..9d7f5a3d2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,167 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.util.ByteBufPayload; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import reactor.core.Exceptions; +import reactor.util.annotation.Nullable; + +final class TestRequesterResponderSupport extends RequesterResponderSupport { + + static final String DATA_CONTENT = "testData"; + static final String METADATA_CONTENT = "testMetadata"; + + final Throwable error; + + TestRequesterResponderSupport( + @Nullable Throwable error, + StreamIdSupplier streamIdSupplier, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + PayloadDecoder.ZERO_COPY, + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT), + streamIdSupplier); + this.error = error; + } + + static Payload genericPayload(LeaksTrackingByteBufAllocator allocator) { + ByteBuf data = allocator.buffer(); + data.writeCharSequence(DATA_CONTENT, CharsetUtil.UTF_8); + + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence(METADATA_CONTENT, CharsetUtil.UTF_8); + + return ByteBufPayload.create(data, metadata); + } + + static Payload randomPayload(LeaksTrackingByteBufAllocator allocator) { + boolean hasMetadata = ThreadLocalRandom.current().nextBoolean(); + ByteBuf metadataByteBuf; + if (hasMetadata) { + byte[] randomMetadata = new byte[ThreadLocalRandom.current().nextInt(0, 512)]; + ThreadLocalRandom.current().nextBytes(randomMetadata); + metadataByteBuf = allocator.buffer().writeBytes(randomMetadata); + } else { + metadataByteBuf = null; + } + byte[] randomData = new byte[ThreadLocalRandom.current().nextInt(512, 1024)]; + ThreadLocalRandom.current().nextBytes(randomData); + + ByteBuf dataByteBuf = allocator.buffer().writeBytes(randomData); + return ByteBufPayload.create(dataByteBuf, metadataByteBuf); + } + + static ArrayList prepareFragments( + LeaksTrackingByteBufAllocator allocator, int mtu, Payload payload) { + boolean hasMetadata = payload.hasMetadata(); + ByteBuf data = payload.sliceData(); + ByteBuf metadata = payload.sliceMetadata(); + ArrayList fragments = new ArrayList<>(); + + fragments.add( + FragmentationUtils.encodeFirstFragment( + allocator, mtu, FrameType.NEXT_COMPLETE, 1, hasMetadata, metadata, data)); + + while (metadata.isReadable() || data.isReadable()) { + fragments.add( + FragmentationUtils.encodeFollowsFragment(allocator, mtu, 1, true, metadata, data)); + } + + return fragments; + } + + @Override + public synchronized int getNextStreamId() { + int nextStreamId = super.getNextStreamId(); + + if (error != null) { + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + @Override + public synchronized int addAndGetNextStreamId(FrameHandler frameHandler) { + int nextStreamId = super.addAndGetNextStreamId(frameHandler); + + if (error != null) { + super.remove(nextStreamId, frameHandler); + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + public static TestRequesterResponderSupport client(@Nullable Throwable e) { + return client(0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, e); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize, @Nullable Throwable e) { + return new TestRequesterResponderSupport( + e, StreamIdSupplier.clientSupplier(), mtu, maxFrameLength, maxInboundPayloadSize); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize) { + return client(mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client(int mtu, int maxFrameLength) { + return client(mtu, maxFrameLength, Integer.MAX_VALUE); + } + + public static TestRequesterResponderSupport client(int mtu) { + return client(mtu, FRAME_LENGTH_MASK); + } + + public static TestRequesterResponderSupport client() { + return client(0); + } + + public TestRequesterResponderSupport assertNoActiveStreams() { + Assertions.assertThat(activeStreams).isEmpty(); + return this; + } + + public TestRequesterResponderSupport assertHasStream(int i, FrameHandler stream) { + Assertions.assertThat(activeStreams).containsEntry(i, stream); + return this; + } + + @Override + public LeaksTrackingByteBufAllocator getAllocator() { + return (LeaksTrackingByteBufAllocator) super.getAllocator(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java index b3f596a37..fd05cb7da 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -217,6 +217,7 @@ void fromCustomRSocketException() { assertThat(Exceptions.from(0, byteBuf)) .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") .isInstanceOf(IllegalArgumentException.class); + byteBuf.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java deleted file mode 100644 index 246fa1184..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; -import static org.mockito.Mockito.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.DuplexConnection; -import io.rsocket.buffer.LeaksTrackingByteBufAllocator; -import io.rsocket.frame.*; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -final class FragmentationDuplexConnectionTest { - private static byte[] data = new byte[1024]; - private static byte[] metadata = new byte[1024]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); - - { - Mockito.when(delegate.onClose()).thenReturn(Mono.never()); - } - - @SuppressWarnings("unchecked") - private final ArgumentCaptor> publishers = - ArgumentCaptor.forClass(Publisher.class); - - private LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - - @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") - @Test - void constructorInvalidMaxFragmentSize() { - assertThatIllegalArgumentException() - .isThrownBy( - () -> - new FragmentationDuplexConnection( - delegate, Integer.MIN_VALUE, Integer.MAX_VALUE, "")) - .withMessage("The smallest allowed mtu size is 64 bytes, provided: -2147483648"); - } - - @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") - @Test - void constructorMtuLessThanMin() { - assertThatIllegalArgumentException() - .isThrownBy(() -> new FragmentationDuplexConnection(delegate, 2, Integer.MAX_VALUE, "")) - .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); - } - - @DisplayName("constructor throws NullPointerException with null delegate") - @Test - void constructorNullDelegate() { - assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(null, 64, Integer.MAX_VALUE, "")) - .withMessage("delegate must not be null"); - } - - @DisplayName("fragments data") - @Test - void sendData() { - ByteBuf encode = - RequestResponseFrameCodec.encode( - allocator, 1, false, Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(data)); - - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new FragmentationDuplexConnection(delegate, 64, Integer.MAX_VALUE, "").sendOne(encode.retain()); - - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())) - .expectNextCount(17) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java deleted file mode 100644 index d27905f90..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java +++ /dev/null @@ -1,56 +0,0 @@ -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameUtil; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.util.DefaultPayload; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; - -public class FragmentationIntegrationTest { - private static byte[] data = new byte[128]; - private static byte[] metadata = new byte[128]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - - @DisplayName("fragments and reassembles data") - @Test - void fragmentAndReassembleData() { - ByteBuf frame = - PayloadFrameCodec.encodeNextCompleteReleasingPayload( - allocator, 2, DefaultPayload.create(data)); - System.out.println(FrameUtil.toString(frame)); - - frame.retain(); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 64, frame, FrameHeaderCodec.frameType(frame)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - ByteBuf assembled = - Flux.from(fragments) - .doOnNext(byteBuf -> System.out.println(FrameUtil.toString(byteBuf))) - .handle(reassembler::reassembleFrame) - .blockLast(); - - System.out.println("assembled"); - String s = FrameUtil.toString(assembled); - System.out.println(s); - - Assert.assertEquals(FrameHeaderCodec.frameType(frame), FrameHeaderCodec.frameType(assembled)); - Assert.assertEquals(frame.readableBytes(), assembled.readableBytes()); - Assert.assertEquals(PayloadFrameCodec.data(frame), PayloadFrameCodec.data(assembled)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java deleted file mode 100644 index 4548e4696..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ /dev/null @@ -1,350 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.frame.*; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.test.StepVerifier; - -final class FrameFragmenterTest { - private static byte[] data = new byte[4096]; - private static byte[] metadata = new byte[4096]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - - @Test - void testGettingData() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - ByteBuf fnf = - RequestFireAndForgetFrameCodec.encode( - allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - ByteBuf rs = - RequestStreamFrameCodec.encode(allocator, 1, true, 1, null, Unpooled.wrappedBuffer(data)); - ByteBuf rc = - RequestChannelFrameCodec.encode( - allocator, 1, true, false, 1, null, Unpooled.wrappedBuffer(data)); - - ByteBuf data = FrameFragmenter.getData(rr, FrameType.REQUEST_RESPONSE); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - - data = FrameFragmenter.getData(fnf, FrameType.REQUEST_FNF); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - - data = FrameFragmenter.getData(rs, FrameType.REQUEST_STREAM); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - - data = FrameFragmenter.getData(rc, FrameType.REQUEST_CHANNEL); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - } - - @Test - void testGettingMetadata() { - ByteBuf rr = - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - ByteBuf fnf = - RequestFireAndForgetFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - ByteBuf rs = - RequestStreamFrameCodec.encode( - allocator, 1, true, 1, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - ByteBuf rc = - RequestChannelFrameCodec.encode( - allocator, - 1, - true, - false, - 1, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)); - - ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - - data = FrameFragmenter.getMetadata(fnf, FrameType.REQUEST_FNF); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - - data = FrameFragmenter.getMetadata(rs, FrameType.REQUEST_STREAM); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - - data = FrameFragmenter.getMetadata(rc, FrameType.REQUEST_CHANNEL); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - } - - @Test - void returnEmptBufferWhenNoMetadataPresent() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - - ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); - Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); - data.release(); - } - - @DisplayName("encode first frame") - @Test - void encodeFirstFrameWithData() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rr, - FrameType.REQUEST_RESPONSE, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestResponseFrameCodec.data(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); - Assert.assertEquals(byteBuf, data); - - Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first channel frame") - @Test - void encodeFirstWithDataChannel() { - ByteBuf rc = - RequestChannelFrameCodec.encode( - allocator, 1, true, false, 10, null, Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rc, - FrameType.REQUEST_CHANNEL, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_CHANNEL, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertEquals(10, RequestChannelFrameCodec.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestChannelFrameCodec.data(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); - Assert.assertEquals(byteBuf, data); - - Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first stream frame") - @Test - void encodeFirstWithDataStream() { - ByteBuf rc = - RequestStreamFrameCodec.encode(allocator, 1, true, 50, null, Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rc, - FrameType.REQUEST_STREAM, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestStreamFrameCodec.data(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); - Assert.assertEquals(byteBuf, data); - - Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first frame with only metadata") - @Test - void encodeFirstFrameWithMetadata() { - ByteBuf rr = - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rr, - FrameType.REQUEST_RESPONSE, - 1, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestResponseFrameCodec.data(fragment); - Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); - - Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first stream frame with data and metadata") - @Test - void encodeFirstWithDataAndMetadataStream() { - ByteBuf rc = - RequestStreamFrameCodec.encode( - allocator, 1, true, 50, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rc, - FrameType.REQUEST_STREAM, - 1, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestStreamFrameCodec.data(fragment); - Assert.assertEquals(0, data.readableBytes()); - - ByteBuf metadata = RequestStreamFrameCodec.metadata(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.metadata).readSlice(metadata.readableBytes()); - Assert.assertEquals(byteBuf, metadata); - - Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("fragments frame with only data") - @Test - void fragmentData() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE); - - StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) - .expectNextCount(1) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .expectNextCount(2) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("fragments frame with only metadata") - @Test - void fragmentMetadata() { - ByteBuf rr = - RequestStreamFrameCodec.encode( - allocator, 1, true, 10, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_STREAM); - - StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) - .expectNextCount(1) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .expectNextCount(2) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("fragments frame with data and metadata") - @Test - void fragmentDataAndMetadata() { - ByteBuf rr = - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE); - - StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .expectNextCount(6) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java deleted file mode 100644 index 6f9762042..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java +++ /dev/null @@ -1,526 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.rsocket.frame.*; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; -import org.assertj.core.api.Assertions; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Flux; -import reactor.test.StepVerifier; - -final class FrameReassemblerTest { - private static byte[] data = new byte[1024]; - private static byte[] metadata = new byte[1024]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - - @DisplayName("reassembles data") - @Test - void reassembleData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - ReferenceCountUtil.safeRelease(data); - } - - @DisplayName("pass through frames without follows") - @Test - void passthrough() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, false, null, Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents(true, Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - ReferenceCountUtil.safeRelease(data); - } - - @DisplayName("reassembles metadata") - @Test - void reassembleMetadata() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata request channel") - @Test - void reassembleMetadataChannel() { - List byteBufs = - Arrays.asList( - RequestChannelFrameCodec.encode( - allocator, - 1, - true, - false, - 100, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestChannelFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - Assert.assertEquals(100, RequestChannelFrameCodec.initialRequestN(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - - ReferenceCountUtil.safeRelease(metadata); - } - - @DisplayName("reassembles metadata request stream") - @Test - void reassembleMetadataStream() { - List byteBufs = - Arrays.asList( - RequestStreamFrameCodec.encode( - allocator, 1, true, 250, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestStreamFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - Assert.assertEquals(250, RequestChannelFrameCodec.initialRequestN(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - - ReferenceCountUtil.safeRelease(metadata); - } - - @DisplayName("reassembles metadata and data") - @Test - void reassembleMetadataAndData() { - - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); - }) - .verifyComplete(); - ReferenceCountUtil.safeRelease(data); - ReferenceCountUtil.safeRelease(metadata); - } - - @DisplayName("cancel removes inflight frames") - @Test - public void cancelBeforeAssembling() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); - - Assert.assertTrue(reassembler.headers.containsKey(1)); - Assert.assertTrue(reassembler.metadata.containsKey(1)); - Assert.assertTrue(reassembler.data.containsKey(1)); - - Flux.just(CancelFrameCodec.encode(allocator, 1)) - .handle(reassembler::reassembleFrame) - .blockLast(); - - Assert.assertFalse(reassembler.headers.containsKey(1)); - Assert.assertFalse(reassembler.metadata.containsKey(1)); - Assert.assertFalse(reassembler.data.containsKey(1)); - } - - @ParameterizedTest(name = "throws error if reassembling payload size exist {0}") - @ValueSource(ints = {64, 1024, 2048, 4096}) - public void errorTooBigPayload(int maxFrameLength) { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, maxFrameLength); - - Assertions.assertThatThrownBy( - Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame)::blockLast) - .hasMessage("Reassembled payload went out of allowed size") - .isExactlyInstanceOf(IllegalStateException.class); - } - - @DisplayName("dispose should clean up maps") - @Test - public void dispose() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); - - Assert.assertTrue(reassembler.headers.containsKey(1)); - Assert.assertTrue(reassembler.metadata.containsKey(1)); - Assert.assertTrue(reassembler.data.containsKey(1)); - - reassembler.dispose(); - - Assert.assertFalse(reassembler.headers.containsKey(1)); - Assert.assertFalse(reassembler.metadata.containsKey(1)); - Assert.assertFalse(reassembler.data.containsKey(1)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java deleted file mode 100644 index 061c17ada..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java +++ /dev/null @@ -1,334 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import static org.mockito.Mockito.RETURNS_SMART_NULLS; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.ReferenceCounted; -import io.rsocket.DuplexConnection; -import io.rsocket.buffer.LeaksTrackingByteBufAllocator; -import io.rsocket.frame.CancelFrameCodec; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameType; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.frame.RequestResponseFrameCodec; -import java.time.Duration; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; -import org.assertj.core.api.Assertions; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.test.StepVerifier; - -final class ReassembleDuplexConnectionTest { - private static byte[] data = new byte[1024]; - private static byte[] metadata = new byte[1024]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); - - private LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - - @DisplayName("reassembles data") - @Test - void reassembleData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata") - @Test - void reassembleMetadata() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata and data") - @Test - void reassembleMetadataAndData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble a non-fragment frame") - @Test - void reassembleNonFragment() { - ByteBuf encode = - RequestResponseFrameCodec.encode(allocator, 1, false, null, Unpooled.wrappedBuffer(data)); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals( - Unpooled.wrappedBuffer(data), RequestResponseFrameCodec.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble non fragmentable frame") - @Test - void reassembleNonFragmentableFrame() { - ByteBuf encode = CancelFrameCodec.encode(allocator, 2); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.CANCEL, FrameHeaderCodec.frameType(byteBuf)); - }) - .verifyComplete(); - } - - @ParameterizedTest(name = "throws error if reassembling payload size exist {0}") - @ValueSource(ints = {64, 1024, 2048, 4096}) - public void errorTooBigPayload(int maxFrameLength) { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - MonoProcessor onClose = MonoProcessor.create(); - - when(delegate.receive()) - .thenReturn( - Flux.fromIterable(byteBufs) - .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)); - when(delegate.onClose()).thenReturn(onClose); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, maxFrameLength) - .receive() - .doFinally(__ -> onClose.onComplete()) - .as(StepVerifier::create) - .expectErrorSatisfies( - t -> - Assertions.assertThat(t) - .hasMessage("Reassembled payload went out of allowed size") - .isExactlyInstanceOf(IllegalStateException.class)) - .verify(Duration.ofSeconds(1)); - - allocator.assertHasNoLeaks(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index 63300c718..75aa2a5b2 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -27,12 +27,17 @@ protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { try { String normalBufferString = object.toString(); - String prettyHexDump = ByteBufUtil.prettyHexDump((ByteBuf) object); - return new StringBuilder() - .append(normalBufferString) - .append("\n") - .append(prettyHexDump) - .toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 128) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } } catch (IllegalReferenceCountException e) { // noops } diff --git a/rsocket-test/src/main/java/io/rsocket/test/FragmentationTransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/FragmentationTransportTest.java new file mode 100644 index 000000000..214e7bc38 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/FragmentationTransportTest.java @@ -0,0 +1,463 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +public interface FragmentationTransportTest { + + String MOCK_DATA = "test-data"; + String MOCK_METADATA = "metadata"; + String LARGE_DATA = read("words.shakespeare.txt.gz"); + Payload LARGE_PAYLOAD = ByteBufPayload.create(LARGE_DATA, LARGE_DATA); + + static String read(String resourceName) { + + try (BufferedReader br = + new BufferedReader( + new InputStreamReader( + new GZIPInputStream( + Objects.requireNonNull( + FragmentationTransportTest.class + .getClassLoader() + .getResourceAsStream(resourceName)))))) { + + return br.lines().map(String::toLowerCase).collect(Collectors.joining("\n\r")); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @AfterEach + default void close() { + getTransportPair().dispose(); + } + + default Payload createTestPayload(int metadataPresent) { + String metadata1; + + switch (metadataPresent % 5) { + case 0: + metadata1 = null; + break; + case 1: + metadata1 = ""; + break; + default: + metadata1 = MOCK_METADATA; + break; + } + String metadata = metadata1; + + return ByteBufPayload.create(MOCK_DATA, metadata); + } + + @DisplayName("makes 10 fireAndForget requests") + @Test + default void fireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(createTestPayload(i))) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10 fireAndForget with Large Payload in Requests") + @Test + default void largePayloadFireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD.retain())) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + + default RSocket getClient() { + return getTransportPair().getClient(); + } + + Duration getTimeout(); + + TransportPair getTransportPair(); + + @DisplayName("makes 10 metadataPush requests") + @Test + default void metadataPush10() { + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", "test-metadata"))) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10 metadataPush with Large Metadata in requests") + @Test + default void largePayloadMetadataPush10() { + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", LARGE_DATA))) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 0 payloads") + @Test + default void requestChannel0() { + getClient() + .requestChannel(Flux.empty()) + .as(StepVerifier::create) + .expectNextCount(0) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Empty Source")) + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 1 payloads") + @Test + default void requestChannel1() { + getClient() + .requestChannel(Mono.just(createTestPayload(0))) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 200,000 payloads") + @Test + default void requestChannel200_000() { + Flux payloads = Flux.range(0, 200_000).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(200_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 100 large payloads") + @Test + default void largePayloadRequestChannel100() { + Flux payloads = Flux.range(0, 100).map(__ -> LARGE_PAYLOAD.retain()); + + getClient() + .requestChannel(payloads) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 20,000 payloads") + @Test + default void requestChannel20_000() { + Flux payloads = Flux.range(0, 20_000).map(metadataPresent -> createTestPayload(7)); + + getClient() + .requestChannel(payloads) + .doOnNext(this::assertChannelPayload) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(20_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 2,000,000 payloads") + @SlowTest + default void requestChannel2_000_000() { + Flux payloads = Flux.range(0, 2_000_000).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .as(StepVerifier::create) + .expectNextCount(2_000_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 3 payloads") + @Test + default void requestChannel3() { + Flux payloads = Flux.range(0, 3).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(3) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 512 payloads") + @Test + default void requestChannel512() { + Flux payloads = Flux.range(0, 512).map(this::createTestPayload); + + Flux.range(0, 1024) + .flatMap( + v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(Schedulers.elastic()), 12) + .blockLast(); + } + + default void check(Flux payloads) { + getClient() + .requestChannel(payloads) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(512) + .as("expected 512 items") + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestResponse request") + @Test + default void requestResponse1() { + getClient() + .requestResponse(createTestPayload(1)) + .doOnNext(this::assertPayload) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10 requestResponse requests") + @Test + default void requestResponse10() { + Flux.range(1, 10) + .flatMap( + i -> + getClient() + .requestResponse(createTestPayload(i)) + .doOnNext(v -> assertPayload(v)) + .map(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(10) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 100 requestResponse requests") + @Test + default void requestResponse100() { + Flux.range(1, 100) + .flatMap( + i -> + getClient() + .requestResponse(createTestPayload(i)) + .map( + payload -> { + String dataUtf8 = payload.getDataUtf8(); + payload.release(); + return dataUtf8; + })) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 100 requestResponse requests") + @Test + default void largePayloadRequestResponse100() { + Flux.range(1, 100) + .flatMap( + i -> + getClient() + .requestResponse(LARGE_PAYLOAD.retain()) + .map( + payload -> { + String dataUtf8 = payload.getDataUtf8(); + payload.release(); + return dataUtf8; + })) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10,000 requestResponse requests") + @Test + default void requestResponse10_000() { + Flux.range(1, 10_000) + .flatMap( + i -> + getClient() + .requestResponse(createTestPayload(i)) + .map( + payload -> { + String dataUtf8 = payload.getDataUtf8(); + payload.release(); + return dataUtf8; + })) + .as(StepVerifier::create) + .expectNextCount(10_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and receives 10,000 responses") + @Test + default void requestStream10_000() { + getClient() + .requestStream(createTestPayload(3)) + .doOnNext(this::assertPayload) + .map(Payload::release) + .as(StepVerifier::create) + .expectNextCount(10_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and receives 5 responses") + @Test + default void requestStream5() { + getClient() + .requestStream(createTestPayload(3)) + .doOnNext(this::assertPayload) + .map(Payload::release) + .take(5) + .as(StepVerifier::create) + .expectNextCount(5) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and consumes result incrementally") + @Test + default void requestStreamDelayedRequestN() { + getClient() + .requestStream(createTestPayload(3)) + .map(Payload::release) + .take(10) + .as(StepVerifier::create) + .thenRequest(5) + .expectNextCount(5) + .thenRequest(5) + .expectNextCount(5) + .expectComplete() + .verify(getTimeout()); + } + + default void assertPayload(Payload p) { + TransportPair transportPair = getTransportPair(); + if (!transportPair.expectedPayloadData().equals(p.getDataUtf8()) + || !transportPair.expectedPayloadMetadata().equals(p.getMetadataUtf8())) { + throw new IllegalStateException("Unexpected payload"); + } + } + + default void assertChannelPayload(Payload p) { + if (!MOCK_DATA.equals(p.getDataUtf8()) || !MOCK_METADATA.equals(p.getMetadataUtf8())) { + throw new IllegalStateException("Unexpected payload"); + } + } + + final class TransportPair implements Disposable { + private static final String data = "hello world"; + private static final String metadata = "metadata"; + + private final RSocket client; + + private final S server; + + public TransportPair( + Supplier addressSupplier, + BiFunction clientTransportSupplier, + Function> serverTransportSupplier) { + + T address = addressSupplier.get(); + + server = + RSocketServer.create((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .fragment(ThreadLocalRandom.current().nextInt(128, 512)) + .bind(serverTransportSupplier.apply(address)) + .block(); + + client = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMillis(Integer.MAX_VALUE), Duration.ofMillis(Integer.MAX_VALUE)) + .fragment(ThreadLocalRandom.current().nextInt(64, 256)) + .connect(clientTransportSupplier.apply(address, server)) + .doOnError(Throwable::printStackTrace) + .block(); + } + + @Override + public void dispose() { + server.dispose(); + } + + RSocket getClient() { + return client; + } + + public String expectedPayloadData() { + return data; + } + + public String expectedPayloadMetadata() { + return metadata; + } + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index d48700445..d71c2ee21 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -18,7 +18,8 @@ import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.util.DefaultPayload; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -34,27 +35,30 @@ public TestRSocket(String data, String metadata) { @Override public Mono requestResponse(Payload payload) { - return Mono.just(DefaultPayload.create(data, metadata)); + payload.release(); + return Mono.just(ByteBufPayload.create(data, metadata)); } @Override public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000).flatMap(l -> requestResponse(payload)); + payload.release(); + return Flux.range(1, 10_000).flatMap(l -> requestResponse(EmptyPayload.INSTANCE)); } @Override public Mono metadataPush(Payload payload) { + payload.release(); return Mono.empty(); } @Override public Mono fireAndForget(Payload payload) { + payload.release(); return Mono.empty(); } @Override public Flux requestChannel(Publisher payloads) { - // TODO is defensive copy neccesary? - return Flux.from(payloads).map(Payload::retain); + return Flux.from(payloads); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index fc059c7d1..d30d64100 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -27,6 +27,7 @@ import java.io.BufferedReader; import java.io.InputStreamReader; import java.time.Duration; +import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; import java.util.function.Function; @@ -155,7 +156,11 @@ default void requestChannel0() { .requestChannel(Flux.empty()) .as(StepVerifier::create) .expectNextCount(0) - .expectComplete() + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Empty Source")) .verify(getTimeout()); } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java new file mode 100644 index 000000000..870f34221 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java @@ -0,0 +1,42 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.rsocket.test.FragmentationTransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; + +final class TcpFragmentationTransportTest implements FragmentationTransportTest { + + private final TransportPair transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server) -> TcpClientTransport.create(server.address()), + TcpServerTransport::create); + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(2); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java index 3b635b6d0..15f9ae3df 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -65,7 +65,7 @@ final class WebsocketSecureTransportTest implements TransportTest { @Override public Duration getTimeout() { - return Duration.ofMinutes(3); + return Duration.ofMinutes(5); } @Override