From 4a3d46098d06cfc5ac02d2cda662d418dce1ab9d Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Fri, 16 Oct 2020 00:07:59 +0300 Subject: [PATCH 1/3] provides handling of requestChannel with complete fla Signed-off-by: Oleh Dokuka --- .../src/main/java/io/rsocket/frame/FrameHeaderCodec.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java index fc146c935..57255dbe4 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -117,6 +117,13 @@ public static FrameType frameType(ByteBuf byteBuf) { } else { throw new IllegalArgumentException("Payload must set either or both of NEXT and COMPLETE."); } + } else if (FrameType.REQUEST_CHANNEL == result) { + final int flags = typeAndFlags & FRAME_FLAGS_MASK; + + boolean complete = FLAGS_C == (flags & FLAGS_C); + if (complete) { + result = FrameType.REQUEST_CHANNEL_COMPLETE; + } } byteBuf.resetReaderIndex(); From 8c7dfa7a8d7659f8e1733f74c1a1541358dc424d Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Mon, 12 Oct 2020 19:14:39 +0300 Subject: [PATCH 2/3] provides request intercepting api Signed-off-by: Oleh Dokuka --- .../core/FireAndForgetRequesterMono.java | 40 +++++ .../FireAndForgetResponderSubscriber.java | 54 ++++++- .../io/rsocket/core/RSocketConnector.java | 4 +- .../io/rsocket/core/RSocketRequester.java | 15 +- .../io/rsocket/core/RSocketResponder.java | 18 ++- .../java/io/rsocket/core/RSocketServer.java | 4 +- .../core/RequestChannelRequesterFlux.java | 103 ++++++++++-- .../RequestChannelResponderSubscriber.java | 151 ++++++++++++++++-- .../core/RequestResponseRequesterMono.java | 48 +++++- .../RequestResponseResponderSubscriber.java | 60 ++++++- .../core/RequestStreamRequesterFlux.java | 32 ++++ .../RequestStreamResponderSubscriber.java | 118 +++++++++++--- .../core/RequesterResponderSupport.java | 11 +- .../InitializingInterceptorRegistry.java | 29 +++- .../rsocket/plugins/InterceptorRegistry.java | 65 ++++++-- .../rsocket/plugins/RequestInterceptor.java | 42 +++++ .../SafeCompositeRequestInterceptor.java | 64 ++++++++ .../plugins/SafeRequestInterceptor.java | 53 ++++++ .../core/DefaultRSocketClientTests.java | 1 + .../io/rsocket/core/RSocketLeaseTest.java | 4 +- .../core/RSocketRequesterSubscribersTest.java | 1 + .../io/rsocket/core/RSocketRequesterTest.java | 1 + .../io/rsocket/core/RSocketResponderTest.java | 3 +- .../java/io/rsocket/core/RSocketTest.java | 4 +- .../io/rsocket/core/SetupRejectionTest.java | 2 + .../core/TestRequesterResponderSupport.java | 30 +++- .../plugins/TestRequestInterceptor.java | 81 ++++++++++ 27 files changed, 958 insertions(+), 80 deletions(-) create mode 100644 rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java create mode 100644 rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java create mode 100644 rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java create mode 100644 rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java index e51c3e75f..cf5ec38b2 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -25,6 +25,7 @@ import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; import java.time.Duration; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; @@ -33,6 +34,7 @@ import reactor.core.Scannable; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; @@ -51,6 +53,8 @@ final class FireAndForgetRequesterMono extends Mono implements Subscriptio final RequesterResponderSupport requesterResponderSupport; final DuplexConnection connection; + @Nullable final RequestInterceptor requestInterceptor; + FireAndForgetRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { this.allocator = requesterResponderSupport.getAllocator(); this.payload = payload; @@ -58,6 +62,7 @@ final class FireAndForgetRequesterMono extends Mono implements Subscriptio this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -98,9 +103,19 @@ public void subscribe(CoreSubscriber actual) { return; } + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.sliceMetadata()); + } + try { if (isTerminated(this.state)) { p.release(); + + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.CANCEL); + } + return; } @@ -108,11 +123,21 @@ public void subscribe(CoreSubscriber actual) { streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); } catch (Throwable e) { lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + actual.onError(e); return; } lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } + actual.onComplete(); } @@ -162,6 +187,11 @@ public Void block() { throw Exceptions.propagate(t); } + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.sliceMetadata()); + } + try { sendReleasingPayload( streamId, @@ -173,10 +203,20 @@ public Void block() { true); } catch (Throwable e) { lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + throw Exceptions.propagate(e); } lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } + return null; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java index 3a2363d47..81d36b810 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -22,11 +22,14 @@ import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.util.annotation.Nullable; final class FireAndForgetResponderSubscriber implements CoreSubscriber, ResponderFrameHandler { @@ -42,6 +45,8 @@ final class FireAndForgetResponderSubscriber final RSocket handler; final int maxInboundPayloadSize; + @Nullable final RequestInterceptor requestInterceptor; + CompositeByteBuf frames; private FireAndForgetResponderSubscriber() { @@ -51,6 +56,19 @@ private FireAndForgetResponderSubscriber() { this.maxInboundPayloadSize = 0; this.requesterResponderSupport = null; this.handler = null; + this.requestInterceptor = null; + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.frames = null; } @@ -65,6 +83,7 @@ private FireAndForgetResponderSubscriber() { this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; this.handler = handler; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.frames = ReassemblyUtils.addFollowingFrame( @@ -81,11 +100,21 @@ public void onNext(Void voidVal) {} @Override public void onError(Throwable t) { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(this.streamId, SignalType.ON_ERROR); + } + logger.debug("Dropped Outbound error", t); } @Override - public void onComplete() {} + public void onComplete() { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(this.streamId, SignalType.ON_COMPLETE); + } + } @Override public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { @@ -95,11 +124,17 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas ReassemblyUtils.addFollowingFrame( frames, followingFrame, hasFollows, this.maxInboundPayloadSize); } catch (IllegalStateException t) { - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); this.frames = null; frames.release(); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } + logger.debug("Reassembly has failed", t); return; } @@ -114,6 +149,12 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas frames.release(); } catch (Throwable t) { ReferenceCountUtil.safeRelease(frames); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(this.streamId, SignalType.ON_ERROR); + } + logger.debug("Reassembly has failed", t); return; } @@ -127,9 +168,16 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas public final void handleCancel() { final CompositeByteBuf frames = this.frames; if (frames != null) { - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + this.frames = null; frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index 05860476d..04e8e57e7 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -631,6 +631,7 @@ public Mono connect(Supplier transportSupplier) { (int) keepAliveInterval.toMillis(), (int) keepAliveMaxLifeTime.toMillis(), keepAliveHandler, + interceptors.initRequesterRequestInterceptor(), requesterLeaseHandler); RSocket wrappedRSocketRequester = @@ -669,7 +670,8 @@ public Mono connect(Supplier transportSupplier) { responderLeaseHandler, mtu, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + interceptors.initResponderRequestInterceptor()); return wrappedRSocketRequester; }) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 044204225..c636770a8 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -33,6 +33,7 @@ import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.keepalive.KeepAliveSupport; import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Supplier; @@ -75,8 +76,16 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, + @Nullable RequestInterceptor requestInterceptor, RequesterLeaseHandler leaseHandler) { - super(mtu, maxFrameLength, maxInboundPayloadSize, payloadDecoder, connection, streamIdSupplier); + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + streamIdSupplier, + requestInterceptor); this.leaseHandler = leaseHandler; this.onClose = MonoProcessor.create(); @@ -319,6 +328,10 @@ private void terminate(Throwable e) { keepAliveFramesAcceptor.dispose(); } getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } leaseHandler.dispose(); synchronized (this) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 3be97760b..d386ffc43 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -29,6 +29,7 @@ import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -39,6 +40,7 @@ import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ class RSocketResponder extends RequesterResponderSupport implements RSocket { @@ -64,8 +66,16 @@ class RSocketResponder extends RequesterResponderSupport implements RSocket { ResponderLeaseHandler leaseHandler, int mtu, int maxFrameLength, - int maxInboundPayloadSize) { - super(mtu, maxFrameLength, maxInboundPayloadSize, payloadDecoder, connection, null); + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + null, + requestInterceptor); this.requestHandler = requestHandler; @@ -194,6 +204,10 @@ private void cleanup() { cleanUpSendingSubscriptions(); getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } leaseHandlerDisposable.dispose(); requestHandler.dispose(); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index 258306cd2..a309fa1ad 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -420,6 +420,7 @@ private Mono acceptSetup( setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, + interceptors.initRequesterRequestInterceptor(), requesterLeaseHandler); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); @@ -451,7 +452,8 @@ private Mono acceptSetup( responderLeaseHandler, mtu, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + interceptors.initResponderRequestInterceptor()); }) .doFinally(signalType -> setupPayload.release()) .then(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java index 722a7c2c5..0cc1cf1ff 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -34,6 +34,7 @@ import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.Objects; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicLongFieldUpdater; @@ -42,6 +43,7 @@ import reactor.core.*; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -59,6 +61,8 @@ final class RequestChannelRequesterFlux extends Flux final Publisher payloadsPublisher; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestChannelRequesterFlux.class, "state"); @@ -86,6 +90,7 @@ final class RequestChannelRequesterFlux extends Flux this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -203,6 +208,11 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.sliceMetadata()); + } + try { sendReleasingPayload( streamId, @@ -222,6 +232,11 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { this.outboundSubscription.cancel(); this.inboundDone = true; + + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } + this.inboundSubscriber.onError(e); return; } @@ -239,6 +254,9 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); connection.sendFrame(streamId, cancelFrame); + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } return; } @@ -268,16 +286,22 @@ final void sendFollowingPayload(Payload followingPayload) { if (!isValid(mtu, this.maxFrameLength, followingPayload, true)) { followingPayload.release(); - this.cancel(); - final IllegalArgumentException e = new IllegalArgumentException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + this.propagateErrorSafely(e); return; } } catch (IllegalReferenceCountException e) { - this.cancel(); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } this.propagateErrorSafely(e); @@ -297,7 +321,10 @@ final void sendFollowingPayload(Payload followingPayload) { allocator, true); } catch (Throwable e) { - this.cancel(); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } this.propagateErrorSafely(e); } @@ -309,6 +336,11 @@ void propagateErrorSafely(Throwable e) { if (!this.inboundDone) { synchronized (this) { if (!this.inboundDone) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + } + this.inboundDone = true; this.inboundSubscriber.onError(e); } else { @@ -322,16 +354,27 @@ void propagateErrorSafely(Throwable e) { @Override public final void cancel() { + if (!tryCancel()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(this.streamId, SignalType.CANCEL); + } + } + + boolean tryCancel() { long previousState = markTerminated(STATE, this); if (isTerminated(previousState)) { - return; + return false; } this.outboundSubscription.cancel(); if (!isFirstFrameSent(previousState)) { // no need to send anything, since we have not started a stream yet (no logical wire) - return; + return false; } final int streamId = this.streamId; @@ -341,6 +384,8 @@ public final void cancel() { final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); this.connection.sendFrame(streamId, cancelFrame); + + return true; } @Override @@ -376,6 +421,11 @@ public void onError(Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber synchronized (this) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + this.inboundDone = true; this.inboundSubscriber.onError(t); } @@ -405,12 +455,20 @@ public void onComplete() { final int streamId = this.streamId; - if (isInboundTerminated(previousState)) { + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } + } } @Override @@ -428,6 +486,11 @@ public final void handleComplete() { if (isOutboundTerminated(previousState)) { this.requesterResponderSupport.remove(this.streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } } this.inboundSubscriber.onComplete(); @@ -443,7 +506,15 @@ public final void handleError(Throwable cause) { this.inboundDone = true; long previousState = markTerminated(STATE, this); - if (isTerminated(previousState) || isInboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } else if (isInboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + } + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); return; } @@ -455,6 +526,12 @@ public final void handleError(Throwable cause) { this.requesterResponderSupport.remove(streamId, this); this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + this.inboundSubscriber.onError(cause); } @@ -486,11 +563,19 @@ public void handleCancel() { return; } - if (isInboundTerminated(previousState)) { + final boolean inboundTerminated = isInboundTerminated(previousState); + if (inboundTerminated) { this.requesterResponderSupport.remove(this.streamId, this); } this.outboundSubscription.cancel(); + + if (inboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.CANCEL); + } + } } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java index 67816407c..1a3a5152e 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -36,6 +36,7 @@ import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -46,6 +47,8 @@ import reactor.core.Exceptions; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.util.annotation.Nullable; import reactor.util.context.Context; final class RequestChannelResponderSubscriber extends Flux @@ -63,6 +66,8 @@ final class RequestChannelResponderSubscriber extends Flux final DuplexConnection connection; final long firstRequest; + @Nullable final RequestInterceptor requestInterceptor; + final RSocket handler; volatile long state; @@ -99,6 +104,7 @@ public RequestChannelResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; this.firstRequest = firstRequestN; @@ -121,6 +127,7 @@ public RequestChannelResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.firstRequest = firstRequestN; this.firstPayload = firstPayload; @@ -293,12 +300,20 @@ public void cancel() { final int streamId = this.streamId; - if (isOutboundTerminated(previousState)) { + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); this.connection.sendFrame(streamId, cancelFrame); + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + } } @Override @@ -320,10 +335,23 @@ public final void handleCancel() { this.firstPayload = null; firstPayload.release(); } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.CANCEL); + } return; } - this.tryTerminate(true); + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.CANCEL); + } } final long tryTerminate(boolean isFromInbound) { @@ -434,6 +462,11 @@ public final void handleError(Throwable t) { // reached it // needs for disconnected upstream and downstream case this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } } @Override @@ -446,13 +479,21 @@ public void handleComplete() { long previousState = markInboundTerminated(STATE, this); - if (isOutboundTerminated(previousState)) { + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { this.requesterResponderSupport.remove(this.streamId, this); } if (isFirstFrameSent(previousState)) { this.inboundSubscriber.onComplete(); } + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.ON_COMPLETE); + } + } } @Override @@ -468,7 +509,15 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) payload = this.payloadDecoder.apply(frame); } catch (Throwable t) { long previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + } + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); return; } @@ -480,6 +529,10 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) ErrorFrameCodec.encode(this.allocator, streamId, new CanceledException(t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } @@ -514,7 +567,15 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } long previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + } + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); return; } @@ -529,6 +590,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); this.connection.sendFrame(streamId, errorFrame); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + return; } } @@ -549,7 +615,15 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) ReferenceCountUtil.safeRelease(frames); previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + } + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); return; } @@ -563,6 +637,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + return; } @@ -591,12 +670,6 @@ public void onNext(Payload p) { final DuplexConnection connection = this.connection; final ByteBufAllocator allocator = this.allocator; - if (p == null) { - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); - connection.sendFrame(streamId, completeFrame); - return; - } - final int mtu = this.mtu; try { if (!isValid(mtu, this.maxFrameLength, p, false)) { @@ -605,7 +678,18 @@ public void onNext(Payload p) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber long previousState = this.tryTerminate(false); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped( + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), + this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + Operators.onErrorDropped( new IllegalArgumentException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), @@ -620,6 +704,11 @@ public void onNext(Payload p) { new CanceledException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } } catch (IllegalReferenceCountException e) { @@ -627,7 +716,15 @@ public void onNext(Payload p) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber long previousState = this.tryTerminate(false); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); return; } @@ -638,6 +735,11 @@ public void onNext(Payload p) { streamId, new CanceledException("Failed to validate payload. Cause:" + e.getMessage())); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } @@ -646,7 +748,11 @@ public void onNext(Payload p) { } catch (Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber - this.tryTerminate(false); + long previousState = this.tryTerminate(false); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null && !isTerminated(previousState)) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } } } @@ -705,6 +811,11 @@ public void onError(Throwable t) { final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_ERROR); + } } @Override @@ -722,12 +833,20 @@ public void onComplete() { final int streamId = this.streamId; - if (isInboundTerminated(previousState)) { + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } + } } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java index 1706ece32..0c8c0a0ba 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -30,6 +30,7 @@ import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -37,6 +38,7 @@ import reactor.core.Scannable; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; @@ -52,6 +54,8 @@ final class RequestResponseRequesterMono extends Mono final DuplexConnection connection; final PayloadDecoder payloadDecoder; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestResponseRequesterMono.class, "state"); @@ -72,6 +76,7 @@ final class RequestResponseRequesterMono extends Mono this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -141,6 +146,11 @@ void sendFirstPayload(Payload payload, long initialRequestN) { return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.sliceMetadata()); + } + try { sendReleasingPayload( streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, connection, allocator, true); @@ -150,6 +160,10 @@ void sendFirstPayload(Payload payload, long initialRequestN) { sm.remove(streamId, this); + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } + this.actual.onError(e); return; } @@ -164,6 +178,10 @@ void sendFirstPayload(Payload payload, long initialRequestN) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } } } @@ -181,6 +199,11 @@ public final void cancel() { ReassemblyUtils.synchronizedRelease(this, previousState); this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } } else if (!hasRequested(previousState)) { this.payload.release(); } @@ -201,10 +224,15 @@ public final void handlePayload(Payload value) { return; } - final CoreSubscriber a = this.actual; + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); - this.requesterResponderSupport.remove(this.streamId, this); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } + final CoreSubscriber a = this.actual; a.onNext(value); a.onComplete(); } @@ -222,7 +250,13 @@ public final void handleComplete() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } this.actual.onComplete(); } @@ -244,7 +278,13 @@ public final void handleError(Throwable cause) { ReassemblyUtils.synchronizedRelease(this, previousState); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } this.actual.onError(cause); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java index f36211c7d..cdb139c67 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -32,6 +32,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; @@ -39,6 +40,7 @@ import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -55,9 +57,10 @@ final class RequestResponseResponderSubscriber final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; final DuplexConnection connection; - final RSocket handler; + @Nullable final RequestInterceptor requestInterceptor; + boolean done; CompositeByteBuf frames; @@ -79,7 +82,9 @@ public RequestResponseResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; + this.frames = ReassemblyUtils.addFollowingFrame( allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); @@ -94,6 +99,7 @@ public RequestResponseResponderSubscriber( this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.payloadDecoder = null; this.handler = null; @@ -137,6 +143,11 @@ public void onNext(@Nullable Payload p) { if (p == null) { final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } return; } @@ -154,6 +165,11 @@ public void onNext(@Nullable Payload p) { new CanceledException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } } catch (IllegalReferenceCountException e) { @@ -165,13 +181,28 @@ public void onNext(@Nullable Payload p) { streamId, new CanceledException("Failed to validate payload. Cause" + e.getMessage())); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } try { sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, connection, allocator, false); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } } catch (Throwable ignored) { currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } } } @@ -197,6 +228,11 @@ public void onError(Throwable t) { final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } } @Override @@ -216,7 +252,8 @@ public void handleCancel() { // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); final CompositeByteBuf frames = this.frames; if (frames != null) { @@ -224,6 +261,10 @@ public void handleCancel() { frames.release(); } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } return; } @@ -234,6 +275,11 @@ public void handleCancel() { this.requesterResponderSupport.remove(this.streamId, this); currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } } @Override @@ -263,6 +309,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } @@ -289,6 +340,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java index a3107d4d6..bfc98e9ab 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -31,6 +31,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -38,6 +39,7 @@ import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; @@ -53,6 +55,8 @@ final class RequestStreamRequesterFlux extends Flux final DuplexConnection connection; final PayloadDecoder payloadDecoder; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestStreamRequesterFlux.class, "state"); @@ -71,6 +75,7 @@ final class RequestStreamRequesterFlux extends Flux this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -149,6 +154,11 @@ void sendFirstPayload(Payload payload, long initialRequestN) { return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.sliceMetadata()); + } + try { sendReleasingPayload( streamId, @@ -165,6 +175,10 @@ void sendFirstPayload(Payload payload, long initialRequestN) { sm.remove(streamId, this); + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } + this.inboundSubscriber.onError(e); return; } @@ -180,6 +194,9 @@ void sendFirstPayload(Payload payload, long initialRequestN) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); connection.sendFrame(streamId, cancelFrame); + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } return; } @@ -215,6 +232,11 @@ public final void cancel() { ReassemblyUtils.synchronizedRelease(this, previousState); this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } } else if (!hasRequested(previousState)) { // no need to send anything, since the first request has not happened this.payload.release(); @@ -246,6 +268,11 @@ public final void handleComplete() { this.requesterResponderSupport.remove(this.streamId, this); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } + this.inboundSubscriber.onComplete(); } @@ -268,6 +295,11 @@ public final void handleError(Throwable cause) { ReassemblyUtils.synchronizedRelease(this, previousState); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } + this.inboundSubscriber.onError(cause); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java index 620638d9c..cde6e0d6c 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -32,6 +32,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; @@ -39,6 +40,8 @@ import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.util.annotation.Nullable; import reactor.util.context.Context; final class RequestStreamResponderSubscriber @@ -56,6 +59,8 @@ final class RequestStreamResponderSubscriber final RequesterResponderSupport requesterResponderSupport; final DuplexConnection connection; + @Nullable final RequestInterceptor requestInterceptor; + final RSocket handler; volatile Subscription s; @@ -81,6 +86,7 @@ public RequestStreamResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; this.frames = ReassemblyUtils.addFollowingFrame( @@ -97,6 +103,7 @@ public RequestStreamResponderSubscriber( this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.payloadDecoder = null; this.handler = null; @@ -123,20 +130,15 @@ public void onNext(Payload p) { final DuplexConnection sender = this.connection; final ByteBufAllocator allocator = this.allocator; - if (p == null) { - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sender.sendFrame(streamId, completeFrame); - return; - } - final int mtu = this.mtu; try { if (!isValid(mtu, this.maxFrameLength, p, false)) { p.release(); - this.handleCancel(); + if (!this.tryTerminateOnError()) { + return; + } - this.done = true; final ByteBuf errorFrame = ErrorFrameCodec.encode( allocator, @@ -144,28 +146,65 @@ public void onNext(Payload p) { new CanceledException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); sender.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } } catch (IllegalReferenceCountException e) { - this.handleCancel(); - this.done = true; + if (!this.tryTerminateOnError()) { + return; + } + final ByteBuf errorFrame = ErrorFrameCodec.encode( allocator, streamId, new CanceledException("Failed to validate payload. Cause" + e.getMessage())); sender.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } return; } try { sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); } catch (Throwable t) { - this.handleCancel(); - this.done = true; + if (!this.tryTerminateOnError()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } } } + boolean tryTerminateOnError() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return false; + } + + this.done = true; + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return false; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + currentSubscription.cancel(); + + return true; + } + @Override public void onError(Throwable t) { if (this.done) { @@ -186,11 +225,15 @@ public void onError(Throwable t) { } final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } } @Override @@ -206,11 +249,15 @@ public void onComplete() { } final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); this.connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + } } @Override @@ -230,7 +277,8 @@ public final void handleCancel() { // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); final CompositeByteBuf frames = this.frames; if (frames != null) { @@ -238,6 +286,10 @@ public final void handleCancel() { frames.release(); } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } return; } @@ -245,9 +297,15 @@ public final void handleCancel() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.CANCEL); + } } @Override @@ -265,20 +323,26 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); this.frames = null; frames.release(); - logger.debug("Reassembly has failed", t); - // sends error frame from the responder side to tell that something went wrong final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } + + logger.debug("Reassembly has failed", t); return; } @@ -292,19 +356,25 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas S.lazySet(this, Operators.cancelledSubscription()); this.done = true; - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); ReferenceCountUtil.safeRelease(frames); - logger.debug("Reassembly has failed", t); - // sends error frame from the responder side to tell that something went wrong final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + } + + logger.debug("Reassembly has failed", t); return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java index e3f70cede..c24688802 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -5,6 +5,7 @@ import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import reactor.util.annotation.Nullable; class RequesterResponderSupport { @@ -15,6 +16,7 @@ class RequesterResponderSupport { private final PayloadDecoder payloadDecoder; private final ByteBufAllocator allocator; private final DuplexConnection connection; + @Nullable private final RequestInterceptor requestInterceptor; @Nullable final StreamIdSupplier streamIdSupplier; final IntObjectMap activeStreams; @@ -25,7 +27,8 @@ public RequesterResponderSupport( int maxInboundPayloadSize, PayloadDecoder payloadDecoder, DuplexConnection connection, - @Nullable StreamIdSupplier streamIdSupplier) { + @Nullable StreamIdSupplier streamIdSupplier, + @Nullable RequestInterceptor requestInterceptor) { this.activeStreams = new IntObjectHashMap<>(); this.mtu = mtu; @@ -35,6 +38,7 @@ public RequesterResponderSupport( this.allocator = connection.alloc(); this.streamIdSupplier = streamIdSupplier; this.connection = connection; + this.requestInterceptor = requestInterceptor; } public int getMtu() { @@ -61,6 +65,11 @@ public DuplexConnection getDuplexConnection() { return connection; } + @Nullable + public RequestInterceptor getRequestInterceptor() { + return requestInterceptor; + } + /** * Issues next {@code streamId} * diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java index fc032847c..59fe6160b 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -18,6 +18,9 @@ import io.rsocket.DuplexConnection; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; +import java.util.List; +import java.util.function.Supplier; +import reactor.util.annotation.Nullable; /** * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. @@ -25,6 +28,30 @@ */ public class InitializingInterceptorRegistry extends InterceptorRegistry { + @Nullable + public RequestInterceptor initRequesterRequestInterceptor() { + return initRequestInterceptor(getRequesterRequestInterceptors()); + } + + @Nullable + public RequestInterceptor initResponderRequestInterceptor() { + return initRequestInterceptor(getResponderRequestInterceptors()); + } + + @Nullable + RequestInterceptor initRequestInterceptor( + List> interceptors) { + switch (interceptors.size()) { + case 0: + return null; + case 1: + return new SafeRequestInterceptor(interceptors.get(0).get()); + default: + return new SafeCompositeRequestInterceptor( + interceptors.stream().map(Supplier::get).toArray(RequestInterceptor[]::new)); + } + } + public DuplexConnection initConnection( DuplexConnectionInterceptor.Type type, DuplexConnection connection) { for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { @@ -34,7 +61,7 @@ public DuplexConnection initConnection( } public RSocket initRequester(RSocket rsocket) { - for (RSocketInterceptor interceptor : getRequesterInteceptors()) { + for (RSocketInterceptor interceptor : getRequesterInterceptors()) { rsocket = interceptor.apply(rsocket); } return rsocket; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java index 427fa15ae..6fa621a2d 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; +import java.util.function.Supplier; /** * Provides support for registering interceptors at the following levels: @@ -30,16 +31,54 @@ * */ public class InterceptorRegistry { - private List requesterInteceptors = new ArrayList<>(); - private List responderInterceptors = new ArrayList<>(); + private List> requesterRequestInterceptors = + new ArrayList<>(); + private List> responderRequestInterceptors = + new ArrayList<>(); + private List requesterRSocketInterceptors = new ArrayList<>(); + private List responderRSocketInterceptors = new ArrayList<>(); private List socketAcceptorInterceptors = new ArrayList<>(); private List connectionInterceptors = new ArrayList<>(); + /** Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. */ + public InterceptorRegistry forRequesterRequests( + Supplier interceptor) { + requesterRequestInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forRequesterRequests(Supplier)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forRequesterRequests( + Consumer>> consumer) { + consumer.accept(requesterRequestInterceptors); + return this; + } + + /** Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. */ + public InterceptorRegistry forResponderRequests( + Supplier interceptor) { + responderRequestInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forResponderRequests(Supplier)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forResponderRequests( + Consumer>> consumer) { + consumer.accept(responderRequestInterceptors); + return this; + } + /** * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. */ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { - requesterInteceptors.add(interceptor); + requesterRSocketInterceptors.add(interceptor); return this; } @@ -48,7 +87,7 @@ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forRequester(Consumer> consumer) { - consumer.accept(requesterInteceptors); + consumer.accept(requesterRSocketInterceptors); return this; } @@ -57,7 +96,7 @@ public InterceptorRegistry forRequester(Consumer> consu * requests. */ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { - responderInterceptors.add(interceptor); + responderRSocketInterceptors.add(interceptor); return this; } @@ -66,7 +105,7 @@ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forResponder(Consumer> consumer) { - consumer.accept(responderInterceptors); + consumer.accept(responderRSocketInterceptors); return this; } @@ -102,12 +141,20 @@ public InterceptorRegistry forConnection(Consumer getRequesterInteceptors() { - return requesterInteceptors; + List> getRequesterRequestInterceptors() { + return requesterRequestInterceptors; + } + + List> getResponderRequestInterceptors() { + return responderRequestInterceptors; + } + + List getRequesterInterceptors() { + return requesterRSocketInterceptors; } List getResponderInterceptors() { - return responderInterceptors; + return responderRSocketInterceptors; } List getConnectionInterceptors() { diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java new file mode 100644 index 000000000..38c03b94d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -0,0 +1,42 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.Disposable; +import reactor.core.publisher.SignalType; +import reactor.util.annotation.Nullable; + +/** Class used to track the RSocket requests lifecycles. */ +public interface RequestInterceptor extends Disposable { + + /** + * Method which is being invoked on successful acceptance and start of a request + * + * @param streamId used for the request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata provided in the request frame + */ + void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata); + + /** + * Method which is being invoked once a successfully accepted request is terminated + * + * @param streamId used by this request + * @param terminalSignal with which this finished has terminated. Must be one of the following + * signals + */ + void onEnd(int streamId, SignalType terminalSignal); + + /** + * Method which is being invoked on the request rejection. + * + * @param streamId used for the request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata provided in the request frame + */ + void onReject(int streamId, FrameType requestType, @Nullable ByteBuf metadata); +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java new file mode 100644 index 000000000..54af68329 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java @@ -0,0 +1,64 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.util.context.Context; + +public class SafeCompositeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor[] requestInterceptors; + + public SafeCompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { + this.requestInterceptors = requestInterceptors; + } + + @Override + public void dispose() { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + requestInterceptor.dispose(); + } + } + + @Override + public void onStart(int streamId, FrameType requestType, ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onEnd(int streamId, SignalType terminalSignal) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onEnd(streamId, terminalSignal); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onReject(int streamId, FrameType requestType, ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onReject(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java new file mode 100644 index 000000000..27b4d909b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java @@ -0,0 +1,53 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.util.context.Context; + +public class SafeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor requestInterceptor; + + public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + + @Override + public void dispose() { + requestInterceptor.dispose(); + } + + @Override + public boolean isDisposed() { + return requestInterceptor.isDisposed(); + } + + @Override + public void onStart(int streamId, FrameType requestType, ByteBuf metadata) { + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onEnd(int streamId, SignalType terminalSignal) { + try { + requestInterceptor.onEnd(streamId, terminalSignal); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onReject(int streamId, FrameType requestType, ByteBuf metadata) { + try { + requestInterceptor.onReject(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java index d080b166d..573184853 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -543,6 +543,7 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, + null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index ae1282c1e..4bff64ec1 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -107,6 +107,7 @@ void setUp() { 0, 0, null, + null, requesterLeaseHandler); mockRSocketHandler = mock(RSocket.class); @@ -155,7 +156,8 @@ void setUp() { responderLeaseHandler, 0, FRAME_LENGTH_MASK, - Integer.MAX_VALUE); + Integer.MAX_VALUE, + null); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index fda6b61ee..570493faa 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -77,6 +77,7 @@ void setUp() { 0, 0, null, + null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index a0b3ef3f2..6640f8003 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -1375,6 +1375,7 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, + null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 0d0b0f093..414dbc04b 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -1216,7 +1216,8 @@ protected RSocketResponder newRSocket() { ResponderLeaseHandler.None, 0, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + null); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index 38745327e..f50b3ea42 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -569,7 +569,8 @@ public Flux requestChannel(Publisher payloads) { ResponderLeaseHandler.None, 0, FRAME_LENGTH_MASK, - Integer.MAX_VALUE); + Integer.MAX_VALUE, + null); crs = new RSocketRequester( @@ -582,6 +583,7 @@ public Flux requestChannel(Publisher payloads) { 0, 0, null, + null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index b96139fb5..202ea8279 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -61,6 +61,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { 0, 0, null, + null, RequesterLeaseHandler.None); String errorMsg = "error"; @@ -98,6 +99,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { 0, 0, null, + null, RequesterLeaseHandler.None); conn.addToReceivedBuffer( diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java index f81e8a610..332a4433e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -26,6 +26,7 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.util.ArrayList; @@ -47,14 +48,16 @@ final class TestRequesterResponderSupport extends RequesterResponderSupport { DuplexConnection connection, int mtu, int maxFrameLength, - int maxInboundPayloadSize) { + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { super( mtu, maxFrameLength, maxInboundPayloadSize, PayloadDecoder.ZERO_COPY, connection, - streamIdSupplier); + streamIdSupplier, + requestInterceptor); this.error = error; } @@ -182,14 +185,34 @@ public static TestRequesterResponderSupport client( mtu, maxFrameLength, maxInboundPayloadSize, + null, e); } + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize) { + return client(duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, null); + } + public static TestRequesterResponderSupport client( TestDuplexConnection duplexConnection, int mtu, int maxFrameLength, int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + return client( + duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, requestInterceptor, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor, @Nullable Throwable e) { return new TestRequesterResponderSupport( e, @@ -197,7 +220,8 @@ public static TestRequesterResponderSupport client( duplexConnection, mtu, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + requestInterceptor); } public static TestRequesterResponderSupport client( diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java new file mode 100644 index 000000000..1174acedc --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -0,0 +1,81 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Queue; +import org.assertj.core.api.Assertions; +import reactor.core.publisher.SignalType; + +public class TestRequestInterceptor implements RequestInterceptor { + + final Queue events = new MpscUnboundedArrayQueue<>(128); + + @Override + public void dispose() {} + + @Override + public void onStart(int streamId, FrameType requestType, ByteBuf metadata) { + events.add(new Event(EventType.ON_START, streamId, requestType, null)); + } + + @Override + public void onEnd(int streamId, SignalType terminalSignal) { + events.add(new Event(EventType.ON_END, streamId, null, terminalSignal)); + } + + @Override + public void onReject(int streamId, FrameType requestType, ByteBuf metadata) { + events.add(new Event(EventType.ON_REJECT, streamId, requestType, null)); + } + + public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_START) + .hasFieldOrPropertyWithValue("streamId", streamId) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectOnEnd(int streamId, SignalType signalType) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_END) + .hasFieldOrPropertyWithValue("streamId", streamId) + .hasFieldOrPropertyWithValue("signalType", signalType); + + return this; + } + + public TestRequestInterceptor expectNothing() { + final Event event = events.poll(); + + Assertions.assertThat(event).isNull(); + + return this; + } + + static final class Event { + final EventType eventType; + final int streamId; + final FrameType requestType; + final SignalType signalType; + + Event(EventType eventType, int streamId, FrameType requestType, SignalType signalType) { + this.eventType = eventType; + this.streamId = streamId; + this.requestType = requestType; + this.signalType = signalType; + } + } + + enum EventType { + ON_START, + ON_END, + ON_REJECT + } +} From 168450eb3c0aa2afe9ef829cbce0b96c88b51343 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Tue, 13 Oct 2020 14:09:24 +0300 Subject: [PATCH 3/3] provides tests, fixes bugs and polish API Signed-off-by: Oleh Dokuka --- rsocket-core/build.gradle | 1 + .../core/FireAndForgetRequesterMono.java | 91 ++- .../FireAndForgetResponderSubscriber.java | 11 +- .../io/rsocket/core/RSocketConnector.java | 4 +- .../io/rsocket/core/RSocketRequester.java | 5 +- .../io/rsocket/core/RSocketResponder.java | 215 +++-- .../java/io/rsocket/core/RSocketServer.java | 4 +- .../core/RequestChannelRequesterFlux.java | 80 +- .../RequestChannelResponderSubscriber.java | 74 +- .../core/RequestResponseRequesterMono.java | 52 +- .../RequestResponseResponderSubscriber.java | 36 +- .../core/RequestStreamRequesterFlux.java | 54 +- .../RequestStreamResponderSubscriber.java | 35 +- .../core/RequesterResponderSupport.java | 6 +- .../io/rsocket/frame/FrameHeaderCodec.java | 7 - .../plugins/CompositeRequestInterceptor.java | 151 ++++ .../InitializingInterceptorRegistry.java | 24 +- .../rsocket/plugins/InterceptorRegistry.java | 51 +- .../rsocket/plugins/RequestInterceptor.java | 55 +- .../SafeCompositeRequestInterceptor.java | 64 -- .../plugins/SafeRequestInterceptor.java | 53 -- .../core/DefaultRSocketClientTests.java | 2 +- .../core/FireAndForgetRequesterMonoTest.java | 59 +- .../io/rsocket/core/RSocketLeaseTest.java | 220 ++++- .../core/RSocketRequesterSubscribersTest.java | 2 +- .../io/rsocket/core/RSocketRequesterTest.java | 2 +- .../io/rsocket/core/RSocketResponderTest.java | 54 +- .../java/io/rsocket/core/RSocketTest.java | 4 +- ...RequestChannelResponderSubscriberTest.java | 52 +- .../core/RequesterOperatorsRacingTest.java | 338 +++++--- .../core/ResponderOperatorsCommonTest.java | 65 +- .../io/rsocket/core/SetupRejectionTest.java | 4 +- .../core/TestRequesterResponderSupport.java | 27 +- .../plugins/RequestInterceptorTest.java | 754 ++++++++++++++++++ .../plugins/TestRequestInterceptor.java | 98 ++- 35 files changed, 2122 insertions(+), 632 deletions(-) create mode 100644 rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java create mode 100644 rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index 41adbd7a8..53a896aea 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -29,6 +29,7 @@ dependencies { implementation 'org.slf4j:slf4j-api' + testImplementation (project(":rsocket-transport-local")) testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' testImplementation 'org.junit.jupiter:junit-jupiter-api' diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java index cf5ec38b2..dec946bab 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -34,7 +34,6 @@ import reactor.core.Scannable; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; @@ -69,8 +68,15 @@ final class FireAndForgetRequesterMono extends Mono implements Subscriptio public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("FireAndForgetMono allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); return; } @@ -81,14 +87,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - p.release(); - actual.onError( + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + actual.onError(e); return; } @@ -98,14 +118,22 @@ public void subscribe(CoreSubscriber actual) { streamId = this.requesterResponderSupport.getNextStreamId(); } catch (Throwable t) { lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); - actual.onError(Exceptions.unwrap(t)); + + actual.onError(ut); return; } final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.sliceMetadata()); + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); } try { @@ -113,7 +141,7 @@ public void subscribe(CoreSubscriber actual) { p.release(); if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.CANCEL); + interceptor.onCancel(streamId); } return; @@ -125,7 +153,7 @@ public void subscribe(CoreSubscriber actual) { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, e); } actual.onError(e); @@ -135,7 +163,7 @@ public void subscribe(CoreSubscriber actual) { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + interceptor.onTerminate(streamId, null); } actual.onComplete(); @@ -162,19 +190,41 @@ public Void block(Duration m) { public Void block() { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - throw new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw e; } final Payload p = this.payload; try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); - throw new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + throw e; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw Exceptions.propagate(e); } @@ -183,13 +233,20 @@ public Void block() { streamId = this.requesterResponderSupport.getNextStreamId(); } catch (Throwable t) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(Exceptions.unwrap(t), FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); + throw Exceptions.propagate(t); } final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.sliceMetadata()); + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); } try { @@ -205,7 +262,7 @@ public Void block() { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, e); } throw Exceptions.propagate(e); @@ -214,7 +271,7 @@ public Void block() { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + interceptor.onTerminate(streamId, null); } return null; diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java index 81d36b810..889c98fde 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -28,7 +28,6 @@ import org.slf4j.LoggerFactory; import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; -import reactor.core.publisher.SignalType; import reactor.util.annotation.Nullable; final class FireAndForgetResponderSubscriber @@ -102,7 +101,7 @@ public void onNext(Void voidVal) {} public void onError(Throwable t) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(this.streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(this.streamId, t); } logger.debug("Dropped Outbound error", t); @@ -112,7 +111,7 @@ public void onError(Throwable t) { public void onComplete() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(this.streamId, SignalType.ON_COMPLETE); + requestInterceptor.onTerminate(this.streamId, null); } } @@ -132,7 +131,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } logger.debug("Reassembly has failed", t); @@ -152,7 +151,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(this.streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(this.streamId, t); } logger.debug("Reassembly has failed", t); @@ -176,7 +175,7 @@ public final void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index 04e8e57e7..342fd9480 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -631,7 +631,7 @@ public Mono connect(Supplier transportSupplier) { (int) keepAliveInterval.toMillis(), (int) keepAliveMaxLifeTime.toMillis(), keepAliveHandler, - interceptors.initRequesterRequestInterceptor(), + interceptors::initRequesterRequestInterceptor, requesterLeaseHandler); RSocket wrappedRSocketRequester = @@ -671,7 +671,7 @@ public Mono connect(Supplier transportSupplier) { mtu, maxFrameLength, maxInboundPayloadSize, - interceptors.initResponderRequestInterceptor()); + interceptors::initResponderRequestInterceptor); return wrappedRSocketRequester; }) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index c636770a8..f51c14a6d 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -36,6 +36,7 @@ import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -76,7 +77,7 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, - @Nullable RequestInterceptor requestInterceptor, + Function requestInterceptorFunction, RequesterLeaseHandler leaseHandler) { super( mtu, @@ -85,7 +86,7 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { payloadDecoder, connection, streamIdSupplier, - requestInterceptor); + requestInterceptorFunction); this.leaseHandler = leaseHandler; this.onClose = MonoProcessor.create(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index d386ffc43..b8f356493 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -25,7 +25,9 @@ import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.lease.ResponderLeaseHandler; @@ -33,6 +35,7 @@ import java.nio.channels.ClosedChannelException; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -40,7 +43,6 @@ import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.util.annotation.Nullable; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ class RSocketResponder extends RequesterResponderSupport implements RSocket { @@ -67,7 +69,7 @@ class RSocketResponder extends RequesterResponderSupport implements RSocket { int mtu, int maxFrameLength, int maxInboundPayloadSize, - @Nullable RequestInterceptor requestInterceptor) { + Function requestInterceptorFunction) { super( mtu, maxFrameLength, @@ -75,7 +77,7 @@ class RSocketResponder extends RequesterResponderSupport implements RSocket { payloadDecoder, connection, null, - requestInterceptor); + requestInterceptorFunction); this.requestHandler = requestHandler; @@ -103,7 +105,7 @@ private void tryTerminate(Supplier errorSupplier) { if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - cleanup(); + doOnDispose(); } } } @@ -111,12 +113,7 @@ private void tryTerminate(Supplier errorSupplier) { @Override public Mono fireAndForget(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.fireAndForget(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.fireAndForget(payload); } catch (Throwable t) { return Mono.error(t); } @@ -125,12 +122,7 @@ public Mono fireAndForget(Payload payload) { @Override public Mono requestResponse(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestResponse(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.requestResponse(payload); } catch (Throwable t) { return Mono.error(t); } @@ -139,12 +131,7 @@ public Mono requestResponse(Payload payload) { @Override public Flux requestStream(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestStream(payload); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestStream(payload); } catch (Throwable t) { return Flux.error(t); } @@ -153,24 +140,7 @@ public Flux requestStream(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestChannel(payloads); - } else { - return Flux.error(leaseHandler.leaseError()); - } - } catch (Throwable t) { - return Flux.error(t); - } - } - - private Flux requestChannel(Payload payload, Publisher payloads) { - try { - if (leaseHandler.useLease()) { - return requestHandler.requestChannel(payloads); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestChannel(payloads); } catch (Throwable t) { return Flux.error(t); } @@ -200,7 +170,7 @@ public Mono onClose() { return getDuplexConnection().onClose(); } - private void cleanup() { + final void doOnDispose() { cleanUpSendingSubscriptions(); getDuplexConnection().dispose(); @@ -217,7 +187,7 @@ private synchronized void cleanUpSendingSubscriptions() { activeStreams.clear(); } - private void handleFrame(ByteBuf frame) { + final void handleFrame(ByteBuf frame) { try { int streamId = FrameHeaderCodec.streamId(frame); FrameHandler receiver; @@ -316,70 +286,149 @@ private void handleFrame(ByteBuf frame) { } } - private void handleFireAndForget(int streamId, ByteBuf frame) { - if (FrameHeaderCodec.hasFollows(frame)) { - FireAndForgetResponderSubscriber subscriber = - new FireAndForgetResponderSubscriber(streamId, frame, this, this); + final void handleFireAndForget(int streamId, ByteBuf frame) { + if (leaseHandler.useLease()) { + + if (FrameHeaderCodec.hasFollows(frame)) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); + FireAndForgetResponderSubscriber subscriber = + new FireAndForgetResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(new FireAndForgetResponderSubscriber(streamId, this)); + } else { + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + } + } } else { - fireAndForget(super.getPayloadDecoder().apply(frame)) - .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseHandler.leaseError(), + FrameType.REQUEST_FNF, + RequestFireAndForgetFrameCodec.metadata(frame)); + } } } - private void handleRequestResponse(int streamId, ByteBuf frame) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestResponseResponderSubscriber subscriber = - new RequestResponseResponderSubscriber(streamId, frame, this, this); + final void handleRequestResponse(int streamId, ByteBuf frame) { + if (leaseHandler.useLease()) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); - } else { - RequestResponseResponderSubscriber subscriber = - new RequestResponseResponderSubscriber(streamId, this); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, this); - if (this.add(streamId, subscriber)) { - this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + if (this.add(streamId, subscriber)) { + this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final Exception leaseError = leaseHandler.leaseError(); + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); } + sendLeaseRejection(streamId, leaseError); } } - private void handleStream(int streamId, ByteBuf frame, long initialRequestN) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestStreamResponderSubscriber subscriber = - new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); + final void handleStream(int streamId, ByteBuf frame, long initialRequestN) { + if (leaseHandler.useLease()) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); - } else { - RequestStreamResponderSubscriber subscriber = - new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); - if (this.add(streamId, subscriber)) { - this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + this.add(streamId, subscriber); + } else { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + + if (this.add(streamId, subscriber)) { + this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } } + } else { + final Exception leaseError = leaseHandler.leaseError(); + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } - private void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestChannelResponderSubscriber subscriber = - new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); + final void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { + if (leaseHandler.useLease()) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); - this.add(streamId, subscriber); - } else { - final Payload firstPayload = super.getPayloadDecoder().apply(frame); - RequestChannelResponderSubscriber subscriber = - new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); - - if (this.add(streamId, subscriber)) { - this.requestChannel(firstPayload, subscriber).subscribe(subscriber); - if (complete) { - subscriber.handleComplete(); + this.add(streamId, subscriber); + } else { + final Payload firstPayload = super.getPayloadDecoder().apply(frame); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); + + if (this.add(streamId, subscriber)) { + this.requestChannel(subscriber).subscribe(subscriber); + if (complete) { + subscriber.handleComplete(); + } } } + } else { + final Exception leaseError = leaseHandler.leaseError(); + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } + private void sendLeaseRejection(int streamId, Throwable leaseError) { + getDuplexConnection() + .sendFrame(streamId, ErrorFrameCodec.encode(getAllocator(), streamId, leaseError)); + } + private void handleMetadataPush(Mono result) { result.subscribe(MetadataPushResponderSubscriber.INSTANCE); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index a309fa1ad..b1c93f206 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -420,7 +420,7 @@ private Mono acceptSetup( setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, - interceptors.initRequesterRequestInterceptor(), + interceptors::initRequesterRequestInterceptor, requesterLeaseHandler); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); @@ -453,7 +453,7 @@ private Mono acceptSetup( mtu, maxFrameLength, maxInboundPayloadSize, - interceptors.initResponderRequestInterceptor()); + interceptors::initResponderRequestInterceptor); }) .doFinally(signalType -> setupPayload.release()) .then(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java index 0cc1cf1ff..8a57820c5 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -40,10 +40,11 @@ import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; -import reactor.core.*; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -99,8 +100,14 @@ public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestChannelFlux allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestChannelFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + Operators.error(actual, e); return; } @@ -168,13 +175,20 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { if (!isValid(mtu, this.maxFrameLength, firstPayload, true)) { lazyTerminate(STATE, this); - firstPayload.release(); this.outboundSubscription.cancel(); - this.inboundDone = true; - this.inboundSubscriber.onError( + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(e); return; } } catch (IllegalReferenceCountException e) { @@ -182,6 +196,11 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { this.outboundSubscription.cancel(); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + this.inboundDone = true; this.inboundSubscriber.onError(e); return; @@ -199,18 +218,25 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { this.inboundDone = true; final long previousState = markTerminated(STATE, this); - firstPayload.release(); this.outboundSubscription.cancel(); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + firstPayload.release(); + if (!isTerminated(previousState)) { - this.inboundSubscriber.onError(Exceptions.unwrap(t)); + this.inboundSubscriber.onError(ut); } return; } final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.sliceMetadata()); + requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); } try { @@ -225,7 +251,7 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { // TODO: Should be a different flag in case of the scalar // source or if we know in advance upstream is mono false); - } catch (Throwable e) { + } catch (Throwable t) { lazyTerminate(STATE, this); sm.remove(streamId, this); @@ -234,10 +260,10 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { this.inboundDone = true; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } - this.inboundSubscriber.onError(e); + this.inboundSubscriber.onError(t); return; } @@ -255,7 +281,7 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { connection.sendFrame(streamId, cancelFrame); if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } return; } @@ -330,7 +356,7 @@ final void sendFollowingPayload(Payload followingPayload) { } } - void propagateErrorSafely(Throwable e) { + void propagateErrorSafely(Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber if (!this.inboundDone) { @@ -338,17 +364,17 @@ void propagateErrorSafely(Throwable e) { if (!this.inboundDone) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + interceptor.onTerminate(this.streamId, t); } this.inboundDone = true; - this.inboundSubscriber.onError(e); + this.inboundSubscriber.onError(t); } else { - Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } } } else { - Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } } @@ -360,7 +386,7 @@ public final void cancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(this.streamId, SignalType.CANCEL); + requestInterceptor.onCancel(this.streamId); } } @@ -423,7 +449,7 @@ public void onError(Throwable t) { synchronized (this) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, t); } this.inboundDone = true; @@ -466,7 +492,7 @@ public void onComplete() { if (isInboundTerminated) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + interceptor.onTerminate(streamId, null); } } } @@ -489,7 +515,7 @@ public final void handleComplete() { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + interceptor.onTerminate(streamId, null); } } @@ -512,7 +538,7 @@ public final void handleError(Throwable cause) { } else if (isInboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + interceptor.onTerminate(this.streamId, cause); } Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); @@ -529,7 +555,7 @@ public final void handleError(Throwable cause) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, cause); } this.inboundSubscriber.onError(cause); @@ -573,7 +599,7 @@ public void handleCancel() { if (inboundTerminated) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.CANCEL); + interceptor.onTerminate(this.streamId, null); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java index 1a3a5152e..9d4cd5f1e 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -47,7 +47,6 @@ import reactor.core.Exceptions; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -311,7 +310,7 @@ public void cancel() { if (isOutboundTerminated) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, null); } } } @@ -338,7 +337,7 @@ public final void handleCancel() { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.CANCEL); + interceptor.onCancel(this.streamId); } return; } @@ -350,7 +349,7 @@ public final void handleCancel() { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.CANCEL); + interceptor.onCancel(this.streamId); } } @@ -465,7 +464,7 @@ public final void handleError(Throwable t) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(this.streamId, t); } } @@ -491,7 +490,7 @@ public void handleComplete() { if (isOutboundTerminated) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.ON_COMPLETE); + interceptor.onTerminate(this.streamId, null); } } } @@ -515,7 +514,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + interceptor.onTerminate(this.streamId, t); } Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); @@ -531,7 +530,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, t); } return; } @@ -573,7 +572,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + interceptor.onTerminate(this.streamId, e); } Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); @@ -590,9 +589,9 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); this.connection.sendFrame(streamId, errorFrame); - final RequestInterceptor interceptor = requestInterceptor; + final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, e); } return; @@ -621,7 +620,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(this.streamId, SignalType.ON_ERROR); + interceptor.onTerminate(this.streamId, t); } Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); @@ -639,7 +638,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, t); } return; @@ -685,29 +684,28 @@ public void onNext(Payload p) { this.inboundSubscriber.currentContext()); return; } else if (isOutboundTerminated(previousState)) { + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, e); } - Operators.onErrorDropped( - new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), - this.inboundSubscriber.currentContext()); + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); return; } - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); connection.sendFrame(streamId, errorFrame); final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, e); } return; } @@ -722,7 +720,7 @@ public void onNext(Payload p) { } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, e); } Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); @@ -738,7 +736,7 @@ public void onNext(Payload p) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, e); } return; } @@ -751,7 +749,7 @@ public void onNext(Payload p) { long previousState = this.tryTerminate(false); final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null && !isTerminated(previousState)) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, t); } } } @@ -788,15 +786,13 @@ public void onError(Throwable t) { } } - if (!isFirstFrameSent(previousState)) { - if (!hasRequested(previousState)) { - final Payload firstPayload = this.firstPayload; - this.firstPayload = null; - firstPayload.release(); - } - } - - if (wasThrowableAdded && !isInboundTerminated(previousState)) { + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (wasThrowableAdded + && isFirstFrameSent(previousState) + && !isInboundTerminated(previousState)) { Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); if (inboundError != TERMINATED) { // FIXME: must be scheduled on the connection event-loop to achieve serial @@ -814,7 +810,7 @@ public void onError(Throwable t) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_ERROR); + interceptor.onTerminate(streamId, t); } } @@ -844,7 +840,7 @@ public void onComplete() { if (isInboundTerminated) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onEnd(streamId, SignalType.ON_COMPLETE); + interceptor.onTerminate(streamId, null); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java index 0c8c0a0ba..f3c52f648 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -38,7 +38,6 @@ import reactor.core.Scannable; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; @@ -84,8 +83,14 @@ public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestResponseMono allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestResponseMono allows only a single " + "Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); return; } @@ -93,15 +98,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - Operators.error( - actual, + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, p.metadata()); + } + p.release(); + + Operators.error(actual, e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + Operators.error(actual, e); return; } @@ -138,17 +156,23 @@ void sendFirstPayload(Payload payload, long initialRequestN) { this.done = true; final long previousState = markTerminated(STATE, this); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + payload.release(); if (!isTerminated(previousState)) { - this.actual.onError(Exceptions.unwrap(t)); + this.actual.onError(ut); } return; } final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.sliceMetadata()); + requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.metadata()); } try { @@ -161,7 +185,7 @@ void sendFirstPayload(Payload payload, long initialRequestN) { sm.remove(streamId, this); if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, e); } this.actual.onError(e); @@ -180,7 +204,7 @@ void sendFirstPayload(Payload payload, long initialRequestN) { connection.sendFrame(streamId, cancelFrame); if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } } } @@ -202,7 +226,7 @@ public final void cancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } } else if (!hasRequested(previousState)) { this.payload.release(); @@ -229,7 +253,7 @@ public final void handlePayload(Payload value) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + requestInterceptor.onTerminate(streamId, null); } final CoreSubscriber a = this.actual; @@ -255,7 +279,7 @@ public final void handleComplete() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + requestInterceptor.onTerminate(streamId, null); } this.actual.onComplete(); @@ -283,7 +307,7 @@ public final void handleError(Throwable cause) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, cause); } this.actual.onError(cause); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java index cdb139c67..648afff13 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -40,7 +40,6 @@ import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -146,7 +145,7 @@ public void onNext(@Nullable Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + requestInterceptor.onTerminate(streamId, null); } return; } @@ -158,17 +157,15 @@ public void onNext(@Nullable Payload p) { p.release(); - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); connection.sendFrame(streamId, errorFrame); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, e); } return; } @@ -184,7 +181,7 @@ public void onNext(@Nullable Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, e); } return; } @@ -194,14 +191,14 @@ public void onNext(@Nullable Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + requestInterceptor.onTerminate(streamId, null); } - } catch (Throwable ignored) { + } catch (Throwable t) { currentSubscription.cancel(); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } } } @@ -231,7 +228,7 @@ public void onError(Throwable t) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } } @@ -263,7 +260,7 @@ public void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } return; } @@ -272,13 +269,14 @@ public void handleCancel() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); currentSubscription.cancel(); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onCancel(streamId); } } @@ -312,7 +310,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } return; } @@ -343,7 +341,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java index bfc98e9ab..3608eaf52 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -39,7 +39,6 @@ import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; @@ -82,8 +81,14 @@ final class RequestStreamRequesterFlux extends Flux public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestStreamFlux allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestStreamFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); return; } @@ -91,15 +96,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - Operators.error( - actual, + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, p.metadata()); + } + p.release(); + + Operators.error(actual, e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + Operators.error(actual, e); return; } @@ -146,17 +164,23 @@ void sendFirstPayload(Payload payload, long initialRequestN) { this.done = true; final long previousState = markTerminated(STATE, this); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_STREAM, payload.metadata()); + } + payload.release(); if (!isTerminated(previousState)) { - this.inboundSubscriber.onError(Exceptions.unwrap(t)); + this.inboundSubscriber.onError(ut); } return; } final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.sliceMetadata()); + requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.metadata()); } try { @@ -169,17 +193,17 @@ void sendFirstPayload(Payload payload, long initialRequestN) { connection, allocator, false); - } catch (Throwable e) { + } catch (Throwable t) { this.done = true; lazyTerminate(STATE, this); sm.remove(streamId, this); if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } - this.inboundSubscriber.onError(e); + this.inboundSubscriber.onError(t); return; } @@ -195,7 +219,7 @@ void sendFirstPayload(Payload payload, long initialRequestN) { connection.sendFrame(streamId, cancelFrame); if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } return; } @@ -235,7 +259,7 @@ public final void cancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } } else if (!hasRequested(previousState)) { // no need to send anything, since the first request has not happened @@ -270,7 +294,7 @@ public final void handleComplete() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + requestInterceptor.onTerminate(streamId, null); } this.inboundSubscriber.onComplete(); @@ -297,7 +321,7 @@ public final void handleError(Throwable cause) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, cause); } this.inboundSubscriber.onError(cause); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java index cde6e0d6c..6b06bc119 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -40,7 +40,6 @@ import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -139,17 +138,15 @@ public void onNext(Payload p) { return; } - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); sender.sendFrame(streamId, errorFrame); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, e); } return; } @@ -167,7 +164,7 @@ public void onNext(Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, e); } return; } @@ -181,7 +178,7 @@ public void onNext(Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } } } @@ -232,7 +229,7 @@ public void onError(Throwable t) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } } @@ -256,7 +253,7 @@ public void onComplete() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_COMPLETE); + requestInterceptor.onTerminate(streamId, null); } } @@ -288,7 +285,7 @@ public final void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } return; } @@ -304,7 +301,7 @@ public final void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.CANCEL); + requestInterceptor.onCancel(streamId); } } @@ -318,7 +315,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas try { ReassemblyUtils.addFollowingFrame( frames, followingFrame, hasFollows, this.maxInboundPayloadSize); - } catch (IllegalStateException t) { + } catch (IllegalStateException e) { // if subscription is null, it means that streams has not yet reassembled all the fragments // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); @@ -334,15 +331,15 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas ErrorFrameCodec.encode( this.allocator, streamId, - new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); this.connection.sendFrame(streamId, errorFrame); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, e); } - logger.debug("Reassembly has failed", t); + logger.debug("Reassembly has failed", e); return; } @@ -371,7 +368,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onEnd(streamId, SignalType.ON_ERROR); + requestInterceptor.onTerminate(streamId, t); } logger.debug("Reassembly has failed", t); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java index c24688802..2272ceb5f 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -4,8 +4,10 @@ import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.plugins.RequestInterceptor; +import java.util.function.Function; import reactor.util.annotation.Nullable; class RequesterResponderSupport { @@ -28,7 +30,7 @@ public RequesterResponderSupport( PayloadDecoder payloadDecoder, DuplexConnection connection, @Nullable StreamIdSupplier streamIdSupplier, - @Nullable RequestInterceptor requestInterceptor) { + Function requestInterceptorFunction) { this.activeStreams = new IntObjectHashMap<>(); this.mtu = mtu; @@ -38,7 +40,7 @@ public RequesterResponderSupport( this.allocator = connection.alloc(); this.streamIdSupplier = streamIdSupplier; this.connection = connection; - this.requestInterceptor = requestInterceptor; + this.requestInterceptor = requestInterceptorFunction.apply((RSocket) this); } public int getMtu() { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java index 57255dbe4..fc146c935 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -117,13 +117,6 @@ public static FrameType frameType(ByteBuf byteBuf) { } else { throw new IllegalArgumentException("Payload must set either or both of NEXT and COMPLETE."); } - } else if (FrameType.REQUEST_CHANNEL == result) { - final int flags = typeAndFlags & FRAME_FLAGS_MASK; - - boolean complete = FLAGS_C == (flags & FLAGS_C); - if (complete) { - result = FrameType.REQUEST_CHANNEL_COMPLETE; - } } byteBuf.resetReaderIndex(); diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java new file mode 100644 index 000000000..b4e1a1ba3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java @@ -0,0 +1,151 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import java.util.List; +import java.util.function.Function; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +class CompositeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor[] requestInterceptors; + + public CompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { + this.requestInterceptors = requestInterceptors; + } + + @Override + public void dispose() { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + requestInterceptor.dispose(); + } + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable cause) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onTerminate(streamId, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onCancel(int streamId) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onCancel(streamId); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Nullable + static RequestInterceptor create( + RSocket rSocket, List> interceptors) { + switch (interceptors.size()) { + case 0: + return null; + case 1: + return new SafeRequestInterceptor(interceptors.get(0).apply(rSocket)); + default: + return new CompositeRequestInterceptor( + interceptors.stream().map(f -> f.apply(rSocket)).toArray(RequestInterceptor[]::new)); + } + } + + static class SafeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor requestInterceptor; + + public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + + @Override + public void dispose() { + requestInterceptor.dispose(); + } + + @Override + public boolean isDisposed() { + return requestInterceptor.isDisposed(); + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable cause) { + try { + requestInterceptor.onTerminate(streamId, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onCancel(int streamId) { + try { + requestInterceptor.onCancel(streamId); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java index 59fe6160b..be0d8278f 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -18,8 +18,6 @@ import io.rsocket.DuplexConnection; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; -import java.util.List; -import java.util.function.Supplier; import reactor.util.annotation.Nullable; /** @@ -29,27 +27,13 @@ public class InitializingInterceptorRegistry extends InterceptorRegistry { @Nullable - public RequestInterceptor initRequesterRequestInterceptor() { - return initRequestInterceptor(getRequesterRequestInterceptors()); + public RequestInterceptor initRequesterRequestInterceptor(RSocket rSocketRequester) { + return CompositeRequestInterceptor.create(rSocketRequester, getRequesterRequestInterceptors()); } @Nullable - public RequestInterceptor initResponderRequestInterceptor() { - return initRequestInterceptor(getResponderRequestInterceptors()); - } - - @Nullable - RequestInterceptor initRequestInterceptor( - List> interceptors) { - switch (interceptors.size()) { - case 0: - return null; - case 1: - return new SafeRequestInterceptor(interceptors.get(0).get()); - default: - return new SafeCompositeRequestInterceptor( - interceptors.stream().map(Supplier::get).toArray(RequestInterceptor[]::new)); - } + public RequestInterceptor initResponderRequestInterceptor(RSocket rSocketResponder) { + return CompositeRequestInterceptor.create(rSocketResponder, getResponderRequestInterceptors()); } public DuplexConnection initConnection( diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java index 6fa621a2d..0ccc4cb92 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -15,10 +15,11 @@ */ package io.rsocket.plugins; +import io.rsocket.RSocket; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; -import java.util.function.Supplier; +import java.util.function.Function; /** * Provides support for registering interceptors at the following levels: @@ -31,46 +32,38 @@ * */ public class InterceptorRegistry { - private List> requesterRequestInterceptors = + private List> requesterRequestInterceptors = new ArrayList<>(); - private List> responderRequestInterceptors = + private List> responderRequestInterceptors = new ArrayList<>(); private List requesterRSocketInterceptors = new ArrayList<>(); private List responderRSocketInterceptors = new ArrayList<>(); private List socketAcceptorInterceptors = new ArrayList<>(); private List connectionInterceptors = new ArrayList<>(); - /** Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. */ - public InterceptorRegistry forRequesterRequests( - Supplier interceptor) { - requesterRequestInterceptors.add(interceptor); - return this; - } - /** - * Variant of {@link #forRequesterRequests(Supplier)} with access to the list of existing - * registrations. + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 */ - public InterceptorRegistry forRequesterRequests( - Consumer>> consumer) { - consumer.accept(requesterRequestInterceptors); - return this; - } - - /** Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. */ - public InterceptorRegistry forResponderRequests( - Supplier interceptor) { - responderRequestInterceptors.add(interceptor); + public InterceptorRegistry forRequester( + Function interceptor) { + requesterRequestInterceptors.add(interceptor); return this; } /** - * Variant of {@link #forResponderRequests(Supplier)} with access to the list of existing - * registrations. + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 */ - public InterceptorRegistry forResponderRequests( - Consumer>> consumer) { - consumer.accept(responderRequestInterceptors); + public InterceptorRegistry forResponder( + Function interceptor) { + responderRequestInterceptors.add(interceptor); return this; } @@ -141,11 +134,11 @@ public InterceptorRegistry forConnection(Consumer> getRequesterRequestInterceptors() { + List> getRequesterRequestInterceptors() { return requesterRequestInterceptors; } - List> getResponderRequestInterceptors() { + List> getResponderRequestInterceptors() { return responderRequestInterceptors; } diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java index 38c03b94d..5da850837 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -3,40 +3,71 @@ import io.netty.buffer.ByteBuf; import io.rsocket.frame.FrameType; import reactor.core.Disposable; -import reactor.core.publisher.SignalType; import reactor.util.annotation.Nullable; +import reactor.util.context.Context; -/** Class used to track the RSocket requests lifecycles. */ +/** + * Class used to track the RSocket requests lifecycles. The main difference and advantage of this + * interceptor compares to {@link RSocketInterceptor} is that it allows intercepting the initial and + * terminal phases on every individual request. + * + *

Note, if any of the invocations will rise a runtime exception, this exception will be + * caught and be propagated to {@link reactor.core.publisher.Operators#onErrorDropped(Throwable, + * Context)} + * + * @since 1.1 + */ public interface RequestInterceptor extends Disposable { /** - * Method which is being invoked on successful acceptance and start of a request + * Method which is being invoked on successful acceptance and start of a request. * * @param streamId used for the request * @param requestType of the request. Must be one of the following types {@link * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} - * @param metadata provided in the request frame + * @param metadata taken from the initial frame */ void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata); /** - * Method which is being invoked once a successfully accepted request is terminated + * Method which is being invoked once a successfully accepted request is terminated. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onCancel(int)}. * * @param streamId used by this request - * @param terminalSignal with which this finished has terminated. Must be one of the following - * signals + * @param t with which this finished has terminated. Must be one of the following signals */ - void onEnd(int streamId, SignalType terminalSignal); + void onTerminate(int streamId, @Nullable Throwable t); /** - * Method which is being invoked on the request rejection. + * Method which is being invoked once a successfully accepted request is cancelled. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onTerminate(int, Throwable)}. * - * @param streamId used for the request + * @param streamId used by this request + */ + void onCancel(int streamId); + + /** + * Method which is being invoked on the request rejection. This method is being called only if the + * actual request can not be started and is called instead of the {@link #onStart(int, FrameType, + * ByteBuf)} method. The reason for rejection can be one of the following: + * + *

+ * + *

    + *
  • No available {@link io.rsocket.lease.Lease} on the requester or the responder sides + *
  • Invalid {@link io.rsocket.Payload} size or format on the Requester side, so the request + * is being rejected before the actual streamId is generated + *
  • A second subscription on the ongoing Request + *
+ * + * @param rejectionReason exception which causes rejection of a particular request * @param requestType of the request. Must be one of the following types {@link * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} - * @param metadata provided in the request frame + * @param metadata taken from the initial frame */ - void onReject(int streamId, FrameType requestType, @Nullable ByteBuf metadata); + void onReject(Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata); } diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java deleted file mode 100644 index 54af68329..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/SafeCompositeRequestInterceptor.java +++ /dev/null @@ -1,64 +0,0 @@ -package io.rsocket.plugins; - -import io.netty.buffer.ByteBuf; -import io.rsocket.frame.FrameType; -import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; -import reactor.util.context.Context; - -public class SafeCompositeRequestInterceptor implements RequestInterceptor { - - final RequestInterceptor[] requestInterceptors; - - public SafeCompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { - this.requestInterceptors = requestInterceptors; - } - - @Override - public void dispose() { - final RequestInterceptor[] requestInterceptors = this.requestInterceptors; - for (int i = 0; i < requestInterceptors.length; i++) { - final RequestInterceptor requestInterceptor = requestInterceptors[i]; - requestInterceptor.dispose(); - } - } - - @Override - public void onStart(int streamId, FrameType requestType, ByteBuf metadata) { - final RequestInterceptor[] requestInterceptors = this.requestInterceptors; - for (int i = 0; i < requestInterceptors.length; i++) { - final RequestInterceptor requestInterceptor = requestInterceptors[i]; - try { - requestInterceptor.onStart(streamId, requestType, metadata); - } catch (Throwable t) { - Operators.onErrorDropped(t, Context.empty()); - } - } - } - - @Override - public void onEnd(int streamId, SignalType terminalSignal) { - final RequestInterceptor[] requestInterceptors = this.requestInterceptors; - for (int i = 0; i < requestInterceptors.length; i++) { - final RequestInterceptor requestInterceptor = requestInterceptors[i]; - try { - requestInterceptor.onEnd(streamId, terminalSignal); - } catch (Throwable t) { - Operators.onErrorDropped(t, Context.empty()); - } - } - } - - @Override - public void onReject(int streamId, FrameType requestType, ByteBuf metadata) { - final RequestInterceptor[] requestInterceptors = this.requestInterceptors; - for (int i = 0; i < requestInterceptors.length; i++) { - final RequestInterceptor requestInterceptor = requestInterceptors[i]; - try { - requestInterceptor.onReject(streamId, requestType, metadata); - } catch (Throwable t) { - Operators.onErrorDropped(t, Context.empty()); - } - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java deleted file mode 100644 index 27b4d909b..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/SafeRequestInterceptor.java +++ /dev/null @@ -1,53 +0,0 @@ -package io.rsocket.plugins; - -import io.netty.buffer.ByteBuf; -import io.rsocket.frame.FrameType; -import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; -import reactor.util.context.Context; - -public class SafeRequestInterceptor implements RequestInterceptor { - - final RequestInterceptor requestInterceptor; - - public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { - this.requestInterceptor = requestInterceptor; - } - - @Override - public void dispose() { - requestInterceptor.dispose(); - } - - @Override - public boolean isDisposed() { - return requestInterceptor.isDisposed(); - } - - @Override - public void onStart(int streamId, FrameType requestType, ByteBuf metadata) { - try { - requestInterceptor.onStart(streamId, requestType, metadata); - } catch (Throwable t) { - Operators.onErrorDropped(t, Context.empty()); - } - } - - @Override - public void onEnd(int streamId, SignalType terminalSignal) { - try { - requestInterceptor.onEnd(streamId, terminalSignal); - } catch (Throwable t) { - Operators.onErrorDropped(t, Context.empty()); - } - } - - @Override - public void onReject(int streamId, FrameType requestType, ByteBuf metadata) { - try { - requestInterceptor.onReject(streamId, requestType, metadata); - } catch (Throwable t) { - Operators.onErrorDropped(t, Context.empty()); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java index 573184853..b77e51537 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -543,7 +543,7 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, - null, + __ -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java index 0857a2de8..f5422a4bf 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java @@ -14,6 +14,7 @@ import io.rsocket.Payload; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.time.Duration; @@ -46,7 +47,9 @@ public static void setUp() { @ParameterizedTest @MethodSource("frameSent") public void frameShouldBeSentOnSubscription(Consumer monoConsumer) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Payload payload = genericPayload(activeStreams.getAllocator()); final FireAndForgetRequesterMono fireAndForgetRequesterMono = new FireAndForgetRequesterMono(payload, activeStreams); @@ -62,7 +65,6 @@ public void frameShouldBeSentOnSubscription(Consumer // should not add anything to map stateAssert.isTerminated(); activeStreams.assertNoActiveStreams(); - final ByteBuf frame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() @@ -79,6 +81,10 @@ public void frameShouldBeSentOnSubscription(Consumer Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectNothing(); } /** @@ -189,7 +195,9 @@ static Stream> frameSent() { @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") public void shouldErrorOnIncorrectRefCntInGivenPayload( Consumer monoConsumer) { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); final Payload payload = ByteBufPayload.create(""); @@ -210,6 +218,9 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject(FrameType.REQUEST_FNF, new IllegalReferenceCountException("refCnt: 0")) + .expectNothing(); } static Stream> @@ -233,7 +244,9 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( Consumer monoConsumer) { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); @@ -260,6 +273,12 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( streamManager.assertNoActiveStreams(); Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject( + FrameType.REQUEST_FNF, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK))) + .expectNothing(); } static Stream> @@ -289,8 +308,10 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( @ParameterizedTest @MethodSource("shouldErrorIfNoAvailabilitySource") public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RuntimeException exception = new RuntimeException("test"); final TestRequesterResponderSupport streamManager = - TestRequesterResponderSupport.client(new RuntimeException("test")); + TestRequesterResponderSupport.client(exception, testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); final Payload payload = genericPayload(allocator); @@ -311,6 +332,7 @@ public void shouldErrorIfNoAvailability(Consumer mon streamManager.assertNoActiveStreams(); Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor.expectOnReject(FrameType.REQUEST_FNF, exception).expectNothing(); } static Stream> shouldErrorIfNoAvailabilitySource() { @@ -333,7 +355,9 @@ static Stream> shouldErrorIfNoAvailabilityS /** Ensures single subscription happens in case of racing */ @Test public void shouldSubscribeExactlyOnce1() { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); @@ -349,7 +373,7 @@ public void shouldSubscribeExactlyOnce1() { () -> RaceTestUtils.race( () -> { - AtomicReference atomicReference = new AtomicReference(); + AtomicReference atomicReference = new AtomicReference<>(); fireAndForgetRequesterMono.subscribe(null, atomicReference::set); Throwable throwable = atomicReference.get(); if (throwable != null) { @@ -380,6 +404,27 @@ public void shouldSubscribeExactlyOnce1() { stateAssert.isTerminated(); streamManager.assertNoActiveStreams(); + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); } Assertions.assertThat(sender.isEmpty()).isTrue(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 4bff64ec1..a36415cb1 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -33,10 +33,15 @@ import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.LeaseFrameCodec; import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; @@ -64,9 +69,12 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.EmitterProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; import reactor.test.StepVerifier; class RSocketLeaseTest { @@ -107,7 +115,7 @@ void setUp() { 0, 0, null, - null, + __ -> null, requesterLeaseHandler); mockRSocketHandler = mock(RSocket.class); @@ -145,7 +153,25 @@ void setUp() { Publisher payloadPublisher = a.getArgument(0); return Flux.from(payloadPublisher) .doOnNext(ReferenceCounted::release) - .thenMany(Flux.empty()); + .transform( + Operators.lift( + (__, actual) -> + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + actual.onSubscribe(this); + } + + @Override + protected void hookOnComplete() { + actual.onComplete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + actual.onError(throwable); + } + })); }); rSocketResponder = @@ -157,7 +183,7 @@ void setUp() { 0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, - null); + __ -> null); } @Test @@ -357,32 +383,86 @@ void requesterAvailabilityRespectsTransport() { } @ParameterizedTest - @MethodSource("interactions") - void responderMissingLeaseRequestsAreRejected( - BiFunction> interaction) { + @MethodSource("responderInteractions") + void responderMissingLeaseRequestsAreRejected(FrameType frameType) { ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - StepVerifier.create(interaction.apply(rSocketResponder, payload1)) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderPresentLeaseRequestsAreAccepted( - BiFunction> interaction, FrameType frameType) { + @MethodSource("responderInteractions") + void responderPresentLeaseRequestsAreAccepted(FrameType frameType) { leaseSender.onNext(Lease.create(5_000, 2)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux.from(interaction.apply(rSocketResponder, payload1)) - .as(StepVerifier::create) - .expectComplete() - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFireAndForget(1, fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } switch (frameType) { case REQUEST_FNF: @@ -400,41 +480,113 @@ void responderPresentLeaseRequestsAreAccepted( } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderDepletedAllowedLeaseRequestsAreRejected( - BiFunction> interaction) { + @MethodSource("responderInteractions") + void responderDepletedAllowedLeaseRequestsAreRejected(FrameType frameType) { leaseSender.onNext(Lease.create(5_000, 1)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux responder = Flux.from(interaction.apply(rSocketResponder, payload1)); - responder.subscribe(); + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test2", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf fnfFrame2 = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(fnfFrame); + rSocketResponder.handleFrame(fnfFrame2); + fnfFrame.release(); + fnfFrame2.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf requestResponseFrame2 = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(requestResponseFrame); + rSocketResponder.handleFrame(requestResponseFrame2); + requestResponseFrame.release(); + requestResponseFrame2.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + final ByteBuf requestStreamFrame2 = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, 1, payload2); + rSocketResponder.handleFrame(requestStreamFrame); + rSocketResponder.handleFrame(requestStreamFrame2); + requestStreamFrame.release(); + requestStreamFrame2.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + final ByteBuf requestChannelFrame2 = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, true, 1, payload2); + rSocketResponder.handleFrame(requestChannelFrame); + rSocketResponder.handleFrame(requestChannelFrame2); + requestChannelFrame.release(); + requestChannelFrame2.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); - ByteBuf buffer2 = byteBufAllocator.buffer(); - buffer2.writeCharSequence("test", CharsetUtil.UTF_8); - Payload payload2 = ByteBufPayload.create(buffer2); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); - Flux.from(interaction.apply(rSocketResponder, payload2)) - .as(StepVerifier::create) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(2) + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest @@ -530,4 +682,12 @@ static Stream interactions() { (rSocket, payload) -> rSocket.requestChannel(Mono.just(payload)), FrameType.REQUEST_CHANNEL)); } + + static Stream responderInteractions() { + return Stream.of( + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index 570493faa..4c7921db1 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -77,7 +77,7 @@ void setUp() { 0, 0, null, - null, + __ -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 6640f8003..9aa8442d9 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -1375,7 +1375,7 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, - null, + (__) -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 414dbc04b..d796d45e5 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -63,6 +63,8 @@ import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.ByteBufPayload; @@ -269,15 +271,18 @@ protected void hookOnSubscribe(Subscription subscription) { @Test public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); + final MonoProcessor monoProcessor = MonoProcessor.create(); rule.setAcceptingSocket( new RSocket() { @Override public Flux requestChannel(Publisher payloads) { payloads.subscribe(assertSubscriber); - return Flux.never(); + return monoProcessor.flux(); } }, Integer.MAX_VALUE); @@ -303,19 +308,23 @@ public Flux requestChannel(Publisher payloads) { ByteBuf data3 = allocator.buffer(); data3.writeCharSequence("def3", CharsetUtil.UTF_8); ByteBuf nextFrame3 = - PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + PayloadFrameCodec.encode(allocator, 1, false, true, true, metadata3, data3); RaceTestUtils.race( () -> { rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); }, - assertSubscriber::cancel); + () -> { + assertSubscriber.cancel(); + monoProcessor.onComplete(); + }); Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnComplete(1).expectNothing(); } } @@ -323,6 +332,8 @@ public Flux requestChannel(Publisher payloads) { public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); @@ -350,11 +361,13 @@ public Flux requestChannel(Publisher payloads) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); }); Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); } } @@ -363,6 +376,8 @@ public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChann Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); @@ -395,11 +410,12 @@ public Flux requestChannel(Publisher payloads) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); }, parallel); Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); - + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); rule.assertHasNoLeaks(); } } @@ -410,6 +426,8 @@ public Flux requestChannel(Publisher payloads) { Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { FluxSink[] sinks = new FluxSink[1]; AssertSubscriber assertSubscriber = AssertSubscriber.create(); @@ -499,6 +517,7 @@ public Flux requestChannel(Publisher payloads) { return msg.refCnt() == 0; }); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnError(1).expectNothing(); } } @@ -507,6 +526,8 @@ public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStrea Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { FluxSink[] sinks = new FluxSink[1]; @@ -536,6 +557,8 @@ public Flux requestStream(Payload payload) { Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor.expectOnStart(1, REQUEST_STREAM).expectOnCancel(1).expectNothing(); } } @@ -544,6 +567,8 @@ public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestRespo Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; @@ -576,6 +601,16 @@ public void subscribe(CoreSubscriber actual) { Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_CANCEL)) + .expectNothing(); } } @@ -795,7 +830,8 @@ public Flux requestChannel(Publisher payloads) { + ERROR + "} but was {" + frameType(rule.connection.getSent().iterator().next()) - + "}"); + + "}") + .matches(ByteBuf::release); } private static Stream refCntCases() { @@ -1178,6 +1214,7 @@ public static class ServerSocketRule extends AbstractSocketRule requestInterceptor); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index f50b3ea42..785532bcf 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -570,7 +570,7 @@ public Flux requestChannel(Publisher payloads) { 0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, - null); + __ -> null); crs = new RSocketRequester( @@ -583,7 +583,7 @@ public Flux requestChannel(Publisher payloads) { 0, 0, null, - null, + __ -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java index 4b4311a00..54033e249 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java @@ -72,7 +72,7 @@ public static void setUp() { @ValueSource(strings = {"inbound", "outbound", "inboundCancel"}) public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); final TestPublisher publisher = TestPublisher.create(); @@ -373,6 +373,56 @@ public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { } } + @Test + public void streamShouldWorkCorrectlyWhenRacingOutboundErrorWithSubscription() { + RuntimeException exception = new RuntimeException("test"); + + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> publisher.error(exception)); + + stateAssert.isTerminated(); + + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(ERROR) + .hasData("test") + .hasStreamId(1) + .hasNoLeaks(); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + + allocator.assertHasNoLeaks(); + } + } + @Test public void streamShouldWorkCorrectlyWhenRacingHandleCancelWithSubscription() { for (int i = 0; i < 10000; i++) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java index 520dd0196..1396774c4 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java @@ -30,6 +30,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.ArrayList; @@ -39,11 +40,10 @@ import java.util.stream.Stream; import org.assertj.core.api.Assertions; import org.assertj.core.api.Assumptions; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; @@ -169,8 +169,9 @@ public String toString() { @MethodSource("scenarios") public void shouldSubscribeExactlyOnce(Scenario scenario) { for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); final TestRequesterResponderSupport requesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload( @@ -214,6 +215,9 @@ public void shouldSubscribeExactlyOnce(Scenario scenario) { if (requestOperator instanceof FrameHandler) { ((FrameHandler) requestOperator).handleComplete(); + if (scenario.requestType() == REQUEST_CHANNEL) { + ((FrameHandler) requestOperator).handleCancel(); + } } }) .thenCancel() @@ -240,6 +244,29 @@ public void shouldSubscribeExactlyOnce(Scenario scenario) { stepVerifier.verify(Duration.ofSeconds(1)); requesterResponderSupport.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } } } @@ -251,7 +278,9 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); @@ -316,6 +345,12 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { activeStreams.assertNoActiveStreams(); Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } } } @@ -330,7 +365,9 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); @@ -404,6 +441,16 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); } Assertions.assertThat(responsePayload.release()).isTrue(); @@ -419,22 +466,24 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { * Ensures that in case of racing between next element and cancel we will not have any memory * leaks */ - @Test - public void shouldHaveNoLeaksOnNextAndCancelRacing() { + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnNextAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - - StepVerifier.create(requestResponseRequesterMono.doOnNext(Payload::release)) - .expectSubscription() - .expectComplete() - .verifyLater(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + requestOperator.subscribe((AssertSubscriber) assertSubscriber); final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) @@ -446,16 +495,16 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) + .typeOf(scenario.requestType()) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); RaceTestUtils.race( - requestResponseRequesterMono::cancel, - () -> requestResponseRequesterMono.handlePayload(response)); + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handlePayload(response)); - Assertions.assertThat(payload.refCnt()).isZero(); + assertSubscriber.values().forEach(Payload::release); Assertions.assertThat(response.refCnt()).isZero(); activeStreams.assertNoActiveStreams(); @@ -468,10 +517,19 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + assertSubscriber.assertTerminated(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); } Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); - - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); activeStreams.getAllocator().assertHasNoLeaks(); } } @@ -482,84 +540,106 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { * cancel we will not have any memory leaks */ @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(boolean withReassembly) { + @MethodSource("scenarios") + public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + boolean[] withReassemblyOptions = new boolean[] {true, false}; final ArrayList droppedErrors = new ArrayList<>(); Hooks.onErrorDropped(droppedErrors::add); - try { - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); - - final StateAssert stateAssert = - StateAssert.assertThat(requestResponseRequesterMono); - - stateAssert.isUnsubscribed(); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(AssertSubscriber.create(0)); - stateAssert.hasSubscribedFlagOnly(); - - assertSubscriber.request(1); - - stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + try { + for (boolean withReassembly : withReassemblyOptions) { + for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + final StateAssert stateAssert; + if (requestOperator instanceof RequestResponseRequesterMono) { + stateAssert = StateAssert.assertThat((RequestResponseRequesterMono) requestOperator); + } else if (requestOperator instanceof RequestStreamRequesterFlux) { + stateAssert = StateAssert.assertThat((RequestStreamRequesterFlux) requestOperator); + } else { + stateAssert = StateAssert.assertThat((RequestChannelRequesterFlux) requestOperator); + } - final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); - FrameAssert.assertThat(sentFrame) - .isNotNull() - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) - .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) - .hasData(TestRequesterResponderSupport.DATA_CONTENT) - .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) - .hasClientSideStreamId() - .hasStreamId(1) - .hasNoLeaks(); + stateAssert.isUnsubscribed(); + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); - if (withReassembly) { - final ByteBuf fragmentBuf = - activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); - requestResponseRequesterMono.handleNext(fragmentBuf, true, false); - // mimic frameHandler behaviour - fragmentBuf.release(); - } + requestOperator.subscribe((AssertSubscriber) assertSubscriber); - final RuntimeException testException = new RuntimeException("test"); - RaceTestUtils.race( - requestResponseRequesterMono::cancel, - () -> requestResponseRequesterMono.handleError(testException)); + stateAssert.hasSubscribedFlagOnly(); - Assertions.assertThat(payload.refCnt()).isZero(); + assertSubscriber.request(1); - activeStreams.assertNoActiveStreams(); - stateAssert.isTerminated(); + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); - final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); - if (!isEmpty) { - final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); - FrameAssert.assertThat(cancellationFrame) + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) .isNotNull() - .typeOf(FrameType.CANCEL) + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); - Assertions.assertThat(droppedErrors).containsExactly(testException); - } else { - assertSubscriber.assertTerminated().assertErrorMessage("test"); - } - Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + if (withReassembly) { + final ByteBuf fragmentBuf = + activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); + ((RequesterFrameHandler) requestOperator).handleNext(fragmentBuf, true, false); + // mimic frameHandler behaviour + fragmentBuf.release(); + } - stateAssert.isTerminated(); - droppedErrors.clear(); - activeStreams.getAllocator().assertHasNoLeaks(); + final RuntimeException testException = new RuntimeException("test"); + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handleError(testException)); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(droppedErrors).containsExactly(testException); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnError(1) + .expectNothing(); + + assertSubscriber.assertTerminated().assertErrorMessage("test"); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + + stateAssert.isTerminated(); + droppedErrors.clear(); + activeStreams.getAllocator().assertHasNoLeaks(); + } } } finally { Hooks.resetOnErrorDropped(); @@ -583,20 +663,25 @@ public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(boolean with * *

Ensures full serialization of outgoing signal (frames) */ - @Test - public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { + @ParameterizedTest + @MethodSource("scenarios") + public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); RaceTestUtils.race(() -> assertSubscriber.cancel(), () -> assertSubscriber.request(1)); @@ -604,11 +689,7 @@ public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) .isNotNull() - .typeOf(FrameType.REQUEST_RESPONSE) - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) + .typeOf(scenario.requestType()) .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasNoFragmentsFollow() @@ -623,15 +704,17 @@ public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); - } - Assertions.assertThat(payload.refCnt()).isZero(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + ((RequesterFrameHandler) requestOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); - requestResponseRequesterMono.handlePayload(response); Assertions.assertThat(response.refCnt()).isZero(); - activeStreams.assertNoActiveStreams(); Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); @@ -639,20 +722,26 @@ public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { } /** Ensures that CancelFrame is sent exactly once in case of racing between cancel() methods */ - @Test - public void shouldSentCancelFrameExactlyOnce() { + @ParameterizedTest + @MethodSource("scenarios") + public void shouldSentCancelFrameExactlyOnce(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requesterOperator = + scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requesterOperator.subscribe((AssertSubscriber) assertSubscriber); assertSubscriber.request(1); @@ -660,19 +749,15 @@ public void shouldSentCancelFrameExactlyOnce() { FrameAssert.assertThat(sentFrame) .isNotNull() .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) + .typeOf(scenario.requestType()) .hasClientSideStreamId() - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasStreamId(1) .hasNoLeaks(); RaceTestUtils.race( - requestResponseRequesterMono::cancel, requestResponseRequesterMono::cancel); + ((Subscription) requesterOperator)::cancel, ((Subscription) requesterOperator)::cancel); final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(cancelFrame) @@ -682,15 +767,18 @@ public void shouldSentCancelFrameExactlyOnce() { .hasStreamId(1) .hasNoLeaks(); - Assertions.assertThat(payload.refCnt()).isZero(); - activeStreams.assertNoActiveStreams(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + activeStreams.assertNoActiveStreams(); - requestResponseRequesterMono.handlePayload(response); + ((RequesterFrameHandler) requesterOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); Assertions.assertThat(response.refCnt()).isZero(); - requestResponseRequesterMono.handleComplete(); + ((RequesterFrameHandler) requesterOperator).handleComplete(); assertSubscriber.assertNotTerminated(); activeStreams.assertNoActiveStreams(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java index 270bc4a05..4f7821e4a 100755 --- a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java @@ -20,6 +20,7 @@ import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; import static io.rsocket.frame.FrameType.REQUEST_FNF; import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; import io.netty.buffer.ByteBuf; import io.rsocket.FrameAssert; @@ -29,6 +30,8 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import java.util.ArrayList; import java.util.concurrent.ThreadLocalRandom; @@ -86,6 +89,12 @@ public ResponderFrameHandler responseOperator( new RequestResponseResponderSubscriber( streamId, firstFragment, streamManager, handler); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + return subscriber; } @@ -99,6 +108,12 @@ public ResponderFrameHandler responseOperator( RequestResponseResponderSubscriber subscriber = new RequestResponseResponderSubscriber(streamId, streamManager); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + return handler.requestResponse(firstPayload).subscribeWith(subscriber); } @@ -128,6 +143,12 @@ public ResponderFrameHandler responseOperator( RequestStreamResponderSubscriber subscriber = new RequestStreamResponderSubscriber( streamId, initialRequestN, firstFragment, streamManager, handler); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + streamManager.activeStreams.put(streamId, subscriber); return subscriber; } @@ -142,6 +163,12 @@ public ResponderFrameHandler responseOperator( RequestStreamResponderSubscriber subscriber = new RequestStreamResponderSubscriber(streamId, initialRequestN, streamManager); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + return handler.requestStream(firstPayload).subscribeWith(subscriber); } @@ -172,6 +199,12 @@ public ResponderFrameHandler responseOperator( new RequestChannelResponderSubscriber( streamId, initialRequestN, firstFragment, streamManager, handler); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + return subscriber; } @@ -186,6 +219,12 @@ public ResponderFrameHandler responseOperator( new RequestChannelResponderSubscriber( streamId, initialRequestN, firstPayload, streamManager); streamManager.activeStreams.put(streamId, responderSubscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + return handler.requestChannel(responderSubscriber).subscribeWith(responderSubscriber); } @@ -242,8 +281,9 @@ public Flux requestChannel(Publisher payloads) { void shouldHandleRequest(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); TestPublisher testPublisher = TestPublisher.create(); @@ -286,6 +326,9 @@ void shouldHandleRequest(Scenario scenario) { .hasStreamId(1) .hasRequestN(1) .hasNoLeaks(); + + responderFrameHandler.handleComplete(); + testHandler.consumer.assertComplete(); } } @@ -294,6 +337,10 @@ void shouldHandleRequest(Scenario scenario) { .assertValueCount(1) .assertValuesWith(p -> PayloadAssert.assertThat(p).hasNoLeaks()); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); allocator.assertHasNoLeaks(); } @@ -302,8 +349,9 @@ void shouldHandleRequest(Scenario scenario) { void shouldHandleFragmentedRequest(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); TestPublisher testPublisher = TestPublisher.create(); @@ -370,6 +418,11 @@ void shouldHandleFragmentedRequest(Scenario scenario) { firstPayload.release(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + allocator.assertHasNoLeaks(); } @@ -378,8 +431,9 @@ void shouldHandleFragmentedRequest(Scenario scenario) { void shouldHandleInterruptedFragmentation(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); TestPublisher testPublisher = TestPublisher.create(); TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); @@ -413,6 +467,11 @@ void shouldHandleInterruptedFragmentation(Scenario scenario) { testPublisher.assertWasNotSubscribed(); testRequesterResponderSupport.assertNoActiveStreams(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index 202ea8279..fe3f75e1b 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -61,7 +61,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { 0, 0, null, - null, + __ -> null, RequesterLeaseHandler.None); String errorMsg = "error"; @@ -99,7 +99,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { 0, 0, null, - null, + __ -> null, RequesterLeaseHandler.None); conn.addToReceivedBuffer( diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java index 332a4433e..e282d72d5 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -23,6 +23,7 @@ import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; @@ -35,7 +36,7 @@ import reactor.core.Exceptions; import reactor.util.annotation.Nullable; -final class TestRequesterResponderSupport extends RequesterResponderSupport { +final class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { static final String DATA_CONTENT = "testData"; static final String METADATA_CONTENT = "testMetadata"; @@ -57,7 +58,7 @@ final class TestRequesterResponderSupport extends RequesterResponderSupport { PayloadDecoder.ZERO_COPY, connection, streamIdSupplier, - requestInterceptor); + (__) -> requestInterceptor); this.error = error; } @@ -173,6 +174,18 @@ public synchronized int addAndGetNextStreamId(FrameHandler frameHandler) { return nextStreamId; } + public static TestRequesterResponderSupport client( + @Nullable Throwable e, @Nullable RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor, + e); + } + public static TestRequesterResponderSupport client(@Nullable Throwable e) { return client(0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, e); } @@ -241,6 +254,16 @@ public static TestRequesterResponderSupport client() { return client(0); } + public static TestRequesterResponderSupport client(RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor); + } + public TestRequesterResponderSupport assertNoActiveStreams() { Assertions.assertThat(activeStreams).isEmpty(); return this; diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java new file mode 100644 index 000000000..24a035b78 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java @@ -0,0 +1,754 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.FrameType; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.annotation.Nullable; + +public class RequestInterceptorTest { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientRequesterSide(boolean errorOutcome) { + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientResponderSide(boolean errorOutcome) + throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forResponder( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerRequesterSide(boolean errorOutcome) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forResponder( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create().connect(LocalClientTransport.create("test")).block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerResponderSide(boolean errorOutcome) + throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @Test + void ensuresExceptionInTheInterceptorIsHandledProperly() { + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId) { + throw new RuntimeException("testOnCancel"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + StepVerifier.create(rSocket.fireAndForget(DefaultPayload.create("test"))) + .expectSubscription() + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestStream(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestChannel(Flux.just(DefaultPayload.create("test")))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldSupportMultipleInterceptors(boolean errorOutcome) { + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor1 = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequestInterceptor testRequestInterceptor2 = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor) + .forRequester( + (Function) + (__) -> testRequestInterceptor1) + .forRequester( + (Function) + (__) -> testRequestInterceptor2)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + testRequestInterceptor2 + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java index 1174acedc..fe9de7ce1 100644 --- a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -4,8 +4,10 @@ import io.rsocket.frame.FrameType; import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; import java.util.Queue; +import java.util.function.Consumer; import org.assertj.core.api.Assertions; -import reactor.core.publisher.SignalType; +import org.assertj.core.api.Condition; +import reactor.util.annotation.Nullable; public class TestRequestInterceptor implements RequestInterceptor { @@ -15,18 +17,25 @@ public class TestRequestInterceptor implements RequestInterceptor { public void dispose() {} @Override - public void onStart(int streamId, FrameType requestType, ByteBuf metadata) { + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { events.add(new Event(EventType.ON_START, streamId, requestType, null)); } @Override - public void onEnd(int streamId, SignalType terminalSignal) { - events.add(new Event(EventType.ON_END, streamId, null, terminalSignal)); + public void onTerminate(int streamId, @Nullable Throwable t) { + events.add( + new Event(t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, null, t)); } @Override - public void onReject(int streamId, FrameType requestType, ByteBuf metadata) { - events.add(new Event(EventType.ON_REJECT, streamId, requestType, null)); + public void onCancel(int streamId) { + events.add(new Event(EventType.ON_CANCEL, streamId, null, null)); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_REJECT, -1, requestType, rejectionReason)); } public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) { @@ -40,13 +49,62 @@ public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) return this; } - public TestRequestInterceptor expectOnEnd(int streamId, SignalType signalType) { + public TestRequestInterceptor expectOnComplete(int streamId) { final Event event = events.poll(); Assertions.assertThat(event) - .hasFieldOrPropertyWithValue("eventType", EventType.ON_END) - .hasFieldOrPropertyWithValue("streamId", streamId) - .hasFieldOrPropertyWithValue("signalType", signalType); + .hasFieldOrPropertyWithValue("eventType", EventType.ON_COMPLETE) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnError(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_ERROR) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnCancel(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_CANCEL) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor assertNext(Consumer consumer) { + final Event event = events.poll(); + Assertions.assertThat(event).isNotNull(); + + consumer.accept(event); + + return this; + } + + public TestRequestInterceptor expectOnReject(FrameType requestType, Throwable rejectionReason) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_REJECT) + .has( + new Condition<>( + e -> { + Assertions.assertThat(e.error) + .isExactlyInstanceOf(rejectionReason.getClass()) + .hasMessage(rejectionReason.getMessage()) + .hasCause(rejectionReason.getCause()); + return true; + }, + "Has rejection reason which matches to %s", + rejectionReason)) + .hasFieldOrPropertyWithValue("requestType", requestType); return this; } @@ -59,23 +117,25 @@ public TestRequestInterceptor expectNothing() { return this; } - static final class Event { - final EventType eventType; - final int streamId; - final FrameType requestType; - final SignalType signalType; + public static final class Event { + public final EventType eventType; + public final int streamId; + public final FrameType requestType; + public final Throwable error; - Event(EventType eventType, int streamId, FrameType requestType, SignalType signalType) { + Event(EventType eventType, int streamId, FrameType requestType, Throwable error) { this.eventType = eventType; this.streamId = streamId; this.requestType = requestType; - this.signalType = signalType; + this.error = error; } } - enum EventType { + public enum EventType { ON_START, - ON_END, + ON_COMPLETE, + ON_ERROR, + ON_CANCEL, ON_REJECT } }