diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index 41adbd7a8..53a896aea 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -29,6 +29,7 @@ dependencies { implementation 'org.slf4j:slf4j-api' + testImplementation (project(":rsocket-transport-local")) testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' testImplementation 'org.junit.jupiter:junit-jupiter-api' diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java index e51c3e75f..dec946bab 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -25,6 +25,7 @@ import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; import java.time.Duration; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; @@ -51,6 +52,8 @@ final class FireAndForgetRequesterMono extends Mono implements Subscriptio final RequesterResponderSupport requesterResponderSupport; final DuplexConnection connection; + @Nullable final RequestInterceptor requestInterceptor; + FireAndForgetRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { this.allocator = requesterResponderSupport.getAllocator(); this.payload = payload; @@ -58,14 +61,22 @@ final class FireAndForgetRequesterMono extends Mono implements Subscriptio this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("FireAndForgetMono allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); return; } @@ -76,14 +87,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - p.release(); - actual.onError( + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + actual.onError(e); return; } @@ -93,14 +118,32 @@ public void subscribe(CoreSubscriber actual) { streamId = this.requesterResponderSupport.getNextStreamId(); } catch (Throwable t) { lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); - actual.onError(Exceptions.unwrap(t)); + + actual.onError(ut); return; } + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + try { if (isTerminated(this.state)) { p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId); + } + return; } @@ -108,11 +151,21 @@ public void subscribe(CoreSubscriber actual) { streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); } catch (Throwable e) { lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, e); + } + actual.onError(e); return; } lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, null); + } + actual.onComplete(); } @@ -137,19 +190,41 @@ public Void block(Duration m) { public Void block() { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - throw new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw e; } final Payload p = this.payload; try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); - throw new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + throw e; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw Exceptions.propagate(e); } @@ -158,10 +233,22 @@ public Void block() { streamId = this.requesterResponderSupport.getNextStreamId(); } catch (Throwable t) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(Exceptions.unwrap(t), FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); + throw Exceptions.propagate(t); } + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + try { sendReleasingPayload( streamId, @@ -173,10 +260,20 @@ public Void block() { true); } catch (Throwable e) { lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, e); + } + throw Exceptions.propagate(e); } lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, null); + } + return null; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java index 3a2363d47..889c98fde 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -22,11 +22,13 @@ import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; final class FireAndForgetResponderSubscriber implements CoreSubscriber, ResponderFrameHandler { @@ -42,6 +44,8 @@ final class FireAndForgetResponderSubscriber final RSocket handler; final int maxInboundPayloadSize; + @Nullable final RequestInterceptor requestInterceptor; + CompositeByteBuf frames; private FireAndForgetResponderSubscriber() { @@ -51,6 +55,19 @@ private FireAndForgetResponderSubscriber() { this.maxInboundPayloadSize = 0; this.requesterResponderSupport = null; this.handler = null; + this.requestInterceptor = null; + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.frames = null; } @@ -65,6 +82,7 @@ private FireAndForgetResponderSubscriber() { this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; this.handler = handler; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.frames = ReassemblyUtils.addFollowingFrame( @@ -81,11 +99,21 @@ public void onNext(Void voidVal) {} @Override public void onError(Throwable t) { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, t); + } + logger.debug("Dropped Outbound error", t); } @Override - public void onComplete() {} + public void onComplete() { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, null); + } + } @Override public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { @@ -95,11 +123,17 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas ReassemblyUtils.addFollowingFrame( frames, followingFrame, hasFollows, this.maxInboundPayloadSize); } catch (IllegalStateException t) { - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); this.frames = null; frames.release(); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } + logger.debug("Reassembly has failed", t); return; } @@ -114,6 +148,12 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas frames.release(); } catch (Throwable t) { ReferenceCountUtil.safeRelease(frames); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, t); + } + logger.debug("Reassembly has failed", t); return; } @@ -127,9 +167,16 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas public final void handleCancel() { final CompositeByteBuf frames = this.frames; if (frames != null) { - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + this.frames = null; frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index 05860476d..342fd9480 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -631,6 +631,7 @@ public Mono connect(Supplier transportSupplier) { (int) keepAliveInterval.toMillis(), (int) keepAliveMaxLifeTime.toMillis(), keepAliveHandler, + interceptors::initRequesterRequestInterceptor, requesterLeaseHandler); RSocket wrappedRSocketRequester = @@ -669,7 +670,8 @@ public Mono connect(Supplier transportSupplier) { responderLeaseHandler, mtu, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + interceptors::initResponderRequestInterceptor); return wrappedRSocketRequester; }) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 044204225..f51c14a6d 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -33,8 +33,10 @@ import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.keepalive.KeepAliveSupport; import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -75,8 +77,16 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, + Function requestInterceptorFunction, RequesterLeaseHandler leaseHandler) { - super(mtu, maxFrameLength, maxInboundPayloadSize, payloadDecoder, connection, streamIdSupplier); + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + streamIdSupplier, + requestInterceptorFunction); this.leaseHandler = leaseHandler; this.onClose = MonoProcessor.create(); @@ -319,6 +329,10 @@ private void terminate(Throwable e) { keepAliveFramesAcceptor.dispose(); } getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } leaseHandler.dispose(); synchronized (this) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 3be97760b..b8f356493 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -25,13 +25,17 @@ import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -64,8 +68,16 @@ class RSocketResponder extends RequesterResponderSupport implements RSocket { ResponderLeaseHandler leaseHandler, int mtu, int maxFrameLength, - int maxInboundPayloadSize) { - super(mtu, maxFrameLength, maxInboundPayloadSize, payloadDecoder, connection, null); + int maxInboundPayloadSize, + Function requestInterceptorFunction) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + null, + requestInterceptorFunction); this.requestHandler = requestHandler; @@ -93,7 +105,7 @@ private void tryTerminate(Supplier errorSupplier) { if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - cleanup(); + doOnDispose(); } } } @@ -101,12 +113,7 @@ private void tryTerminate(Supplier errorSupplier) { @Override public Mono fireAndForget(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.fireAndForget(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.fireAndForget(payload); } catch (Throwable t) { return Mono.error(t); } @@ -115,12 +122,7 @@ public Mono fireAndForget(Payload payload) { @Override public Mono requestResponse(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestResponse(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.requestResponse(payload); } catch (Throwable t) { return Mono.error(t); } @@ -129,12 +131,7 @@ public Mono requestResponse(Payload payload) { @Override public Flux requestStream(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestStream(payload); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestStream(payload); } catch (Throwable t) { return Flux.error(t); } @@ -143,24 +140,7 @@ public Flux requestStream(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestChannel(payloads); - } else { - return Flux.error(leaseHandler.leaseError()); - } - } catch (Throwable t) { - return Flux.error(t); - } - } - - private Flux requestChannel(Payload payload, Publisher payloads) { - try { - if (leaseHandler.useLease()) { - return requestHandler.requestChannel(payloads); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestChannel(payloads); } catch (Throwable t) { return Flux.error(t); } @@ -190,10 +170,14 @@ public Mono onClose() { return getDuplexConnection().onClose(); } - private void cleanup() { + final void doOnDispose() { cleanUpSendingSubscriptions(); getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } leaseHandlerDisposable.dispose(); requestHandler.dispose(); } @@ -203,7 +187,7 @@ private synchronized void cleanUpSendingSubscriptions() { activeStreams.clear(); } - private void handleFrame(ByteBuf frame) { + final void handleFrame(ByteBuf frame) { try { int streamId = FrameHeaderCodec.streamId(frame); FrameHandler receiver; @@ -302,70 +286,149 @@ private void handleFrame(ByteBuf frame) { } } - private void handleFireAndForget(int streamId, ByteBuf frame) { - if (FrameHeaderCodec.hasFollows(frame)) { - FireAndForgetResponderSubscriber subscriber = - new FireAndForgetResponderSubscriber(streamId, frame, this, this); + final void handleFireAndForget(int streamId, ByteBuf frame) { + if (leaseHandler.useLease()) { + + if (FrameHeaderCodec.hasFollows(frame)) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } + + FireAndForgetResponderSubscriber subscriber = + new FireAndForgetResponderSubscriber(streamId, frame, this, this); - this.add(streamId, subscriber); + this.add(streamId, subscriber); + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(new FireAndForgetResponderSubscriber(streamId, this)); + } else { + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + } + } } else { - fireAndForget(super.getPayloadDecoder().apply(frame)) - .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseHandler.leaseError(), + FrameType.REQUEST_FNF, + RequestFireAndForgetFrameCodec.metadata(frame)); + } } } - private void handleRequestResponse(int streamId, ByteBuf frame) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestResponseResponderSubscriber subscriber = - new RequestResponseResponderSubscriber(streamId, frame, this, this); + final void handleRequestResponse(int streamId, ByteBuf frame) { + if (leaseHandler.useLease()) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); - } else { - RequestResponseResponderSubscriber subscriber = - new RequestResponseResponderSubscriber(streamId, this); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, this); - if (this.add(streamId, subscriber)) { - this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + if (this.add(streamId, subscriber)) { + this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final Exception leaseError = leaseHandler.leaseError(); + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); } + sendLeaseRejection(streamId, leaseError); } } - private void handleStream(int streamId, ByteBuf frame, long initialRequestN) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestStreamResponderSubscriber subscriber = - new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); + final void handleStream(int streamId, ByteBuf frame, long initialRequestN) { + if (leaseHandler.useLease()) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); - } else { - RequestStreamResponderSubscriber subscriber = - new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); - if (this.add(streamId, subscriber)) { - this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + this.add(streamId, subscriber); + } else { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + + if (this.add(streamId, subscriber)) { + this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } } + } else { + final Exception leaseError = leaseHandler.leaseError(); + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } - private void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestChannelResponderSubscriber subscriber = - new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); + final void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { + if (leaseHandler.useLease()) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); - this.add(streamId, subscriber); - } else { - final Payload firstPayload = super.getPayloadDecoder().apply(frame); - RequestChannelResponderSubscriber subscriber = - new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); - - if (this.add(streamId, subscriber)) { - this.requestChannel(firstPayload, subscriber).subscribe(subscriber); - if (complete) { - subscriber.handleComplete(); + this.add(streamId, subscriber); + } else { + final Payload firstPayload = super.getPayloadDecoder().apply(frame); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); + + if (this.add(streamId, subscriber)) { + this.requestChannel(subscriber).subscribe(subscriber); + if (complete) { + subscriber.handleComplete(); + } } } + } else { + final Exception leaseError = leaseHandler.leaseError(); + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } + private void sendLeaseRejection(int streamId, Throwable leaseError) { + getDuplexConnection() + .sendFrame(streamId, ErrorFrameCodec.encode(getAllocator(), streamId, leaseError)); + } + private void handleMetadataPush(Mono result) { result.subscribe(MetadataPushResponderSubscriber.INSTANCE); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index 258306cd2..b1c93f206 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -420,6 +420,7 @@ private Mono acceptSetup( setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, + interceptors::initRequesterRequestInterceptor, requesterLeaseHandler); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); @@ -451,7 +452,8 @@ private Mono acceptSetup( responderLeaseHandler, mtu, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + interceptors::initResponderRequestInterceptor); }) .doFinally(signalType -> setupPayload.release()) .then(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java index 722a7c2c5..8a57820c5 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -34,12 +34,15 @@ import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.Objects; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; -import reactor.core.*; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; import reactor.util.annotation.NonNull; @@ -59,6 +62,8 @@ final class RequestChannelRequesterFlux extends Flux final Publisher payloadsPublisher; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestChannelRequesterFlux.class, "state"); @@ -86,6 +91,7 @@ final class RequestChannelRequesterFlux extends Flux this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -94,8 +100,14 @@ public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestChannelFlux allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestChannelFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + Operators.error(actual, e); return; } @@ -163,13 +175,20 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { if (!isValid(mtu, this.maxFrameLength, firstPayload, true)) { lazyTerminate(STATE, this); - firstPayload.release(); this.outboundSubscription.cancel(); - this.inboundDone = true; - this.inboundSubscriber.onError( + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(e); return; } } catch (IllegalReferenceCountException e) { @@ -177,6 +196,11 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { this.outboundSubscription.cancel(); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + this.inboundDone = true; this.inboundSubscriber.onError(e); return; @@ -194,15 +218,27 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { this.inboundDone = true; final long previousState = markTerminated(STATE, this); - firstPayload.release(); this.outboundSubscription.cancel(); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + firstPayload.release(); + if (!isTerminated(previousState)) { - this.inboundSubscriber.onError(Exceptions.unwrap(t)); + this.inboundSubscriber.onError(ut); } return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + try { sendReleasingPayload( streamId, @@ -215,14 +251,19 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { // TODO: Should be a different flag in case of the scalar // source or if we know in advance upstream is mono false); - } catch (Throwable e) { + } catch (Throwable t) { lazyTerminate(STATE, this); sm.remove(streamId, this); this.outboundSubscription.cancel(); this.inboundDone = true; - this.inboundSubscriber.onError(e); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } + + this.inboundSubscriber.onError(t); return; } @@ -239,6 +280,9 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); connection.sendFrame(streamId, cancelFrame); + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } return; } @@ -268,16 +312,22 @@ final void sendFollowingPayload(Payload followingPayload) { if (!isValid(mtu, this.maxFrameLength, followingPayload, true)) { followingPayload.release(); - this.cancel(); - final IllegalArgumentException e = new IllegalArgumentException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + this.propagateErrorSafely(e); return; } } catch (IllegalReferenceCountException e) { - this.cancel(); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } this.propagateErrorSafely(e); @@ -297,41 +347,60 @@ final void sendFollowingPayload(Payload followingPayload) { allocator, true); } catch (Throwable e) { - this.cancel(); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } this.propagateErrorSafely(e); } } - void propagateErrorSafely(Throwable e) { + void propagateErrorSafely(Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber if (!this.inboundDone) { synchronized (this) { if (!this.inboundDone) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, t); + } + this.inboundDone = true; - this.inboundSubscriber.onError(e); + this.inboundSubscriber.onError(t); } else { - Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } } } else { - Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } } @Override public final void cancel() { + if (!tryCancel()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(this.streamId); + } + } + + boolean tryCancel() { long previousState = markTerminated(STATE, this); if (isTerminated(previousState)) { - return; + return false; } this.outboundSubscription.cancel(); if (!isFirstFrameSent(previousState)) { // no need to send anything, since we have not started a stream yet (no logical wire) - return; + return false; } final int streamId = this.streamId; @@ -341,6 +410,8 @@ public final void cancel() { final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); this.connection.sendFrame(streamId, cancelFrame); + + return true; } @Override @@ -376,6 +447,11 @@ public void onError(Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber synchronized (this) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, t); + } + this.inboundDone = true; this.inboundSubscriber.onError(t); } @@ -405,12 +481,20 @@ public void onComplete() { final int streamId = this.streamId; - if (isInboundTerminated(previousState)) { + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, null); + } + } } @Override @@ -428,6 +512,11 @@ public final void handleComplete() { if (isOutboundTerminated(previousState)) { this.requesterResponderSupport.remove(this.streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, null); + } } this.inboundSubscriber.onComplete(); @@ -443,7 +532,15 @@ public final void handleError(Throwable cause) { this.inboundDone = true; long previousState = markTerminated(STATE, this); - if (isTerminated(previousState) || isInboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } else if (isInboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, cause); + } + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); return; } @@ -455,6 +552,12 @@ public final void handleError(Throwable cause) { this.requesterResponderSupport.remove(streamId, this); this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, cause); + } + this.inboundSubscriber.onError(cause); } @@ -486,11 +589,19 @@ public void handleCancel() { return; } - if (isInboundTerminated(previousState)) { + final boolean inboundTerminated = isInboundTerminated(previousState); + if (inboundTerminated) { this.requesterResponderSupport.remove(this.streamId, this); } this.outboundSubscription.cancel(); + + if (inboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, null); + } + } } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java index 67816407c..9d4cd5f1e 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -36,6 +36,7 @@ import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -46,6 +47,7 @@ import reactor.core.Exceptions; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; import reactor.util.context.Context; final class RequestChannelResponderSubscriber extends Flux @@ -63,6 +65,8 @@ final class RequestChannelResponderSubscriber extends Flux final DuplexConnection connection; final long firstRequest; + @Nullable final RequestInterceptor requestInterceptor; + final RSocket handler; volatile long state; @@ -99,6 +103,7 @@ public RequestChannelResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; this.firstRequest = firstRequestN; @@ -121,6 +126,7 @@ public RequestChannelResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.firstRequest = firstRequestN; this.firstPayload = firstPayload; @@ -293,12 +299,20 @@ public void cancel() { final int streamId = this.streamId; - if (isOutboundTerminated(previousState)) { + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); this.connection.sendFrame(streamId, cancelFrame); + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, null); + } + } } @Override @@ -320,10 +334,23 @@ public final void handleCancel() { this.firstPayload = null; firstPayload.release(); } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId); + } + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { return; } - this.tryTerminate(true); + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId); + } } final long tryTerminate(boolean isFromInbound) { @@ -434,6 +461,11 @@ public final void handleError(Throwable t) { // reached it // needs for disconnected upstream and downstream case this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, t); + } } @Override @@ -446,13 +478,21 @@ public void handleComplete() { long previousState = markInboundTerminated(STATE, this); - if (isOutboundTerminated(previousState)) { + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { this.requesterResponderSupport.remove(this.streamId, this); } if (isFirstFrameSent(previousState)) { this.inboundSubscriber.onComplete(); } + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, null); + } + } } @Override @@ -468,7 +508,15 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) payload = this.payloadDecoder.apply(frame); } catch (Throwable t) { long previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, t); + } + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); return; } @@ -480,6 +528,10 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) ErrorFrameCodec.encode(this.allocator, streamId, new CanceledException(t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, t); + } return; } @@ -514,7 +566,15 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } long previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, e); + } + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); return; } @@ -529,6 +589,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); this.connection.sendFrame(streamId, errorFrame); + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, e); + } + return; } } @@ -549,7 +614,15 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) ReferenceCountUtil.safeRelease(frames); previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, t); + } + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); return; } @@ -563,6 +636,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, t); + } + return; } @@ -591,12 +669,6 @@ public void onNext(Payload p) { final DuplexConnection connection = this.connection; final ByteBufAllocator allocator = this.allocator; - if (p == null) { - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); - connection.sendFrame(streamId, completeFrame); - return; - } - final int mtu = this.mtu; try { if (!isValid(mtu, this.maxFrameLength, p, false)) { @@ -605,21 +677,36 @@ public void onNext(Payload p) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber long previousState = this.tryTerminate(false); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { Operators.onErrorDropped( new IllegalArgumentException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), this.inboundSubscriber.currentContext()); return; + } else if (isOutboundTerminated(previousState)) { + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; } - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, e); + } return; } } catch (IllegalReferenceCountException e) { @@ -627,7 +714,15 @@ public void onNext(Payload p) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber long previousState = this.tryTerminate(false); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, e); + } + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); return; } @@ -638,6 +733,11 @@ public void onNext(Payload p) { streamId, new CanceledException("Failed to validate payload. Cause:" + e.getMessage())); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, e); + } return; } @@ -646,7 +746,11 @@ public void onNext(Payload p) { } catch (Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber - this.tryTerminate(false); + long previousState = this.tryTerminate(false); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null && !isTerminated(previousState)) { + interceptor.onTerminate(streamId, t); + } } } @@ -682,15 +786,13 @@ public void onError(Throwable t) { } } - if (!isFirstFrameSent(previousState)) { - if (!hasRequested(previousState)) { - final Payload firstPayload = this.firstPayload; - this.firstPayload = null; - firstPayload.release(); - } - } - - if (wasThrowableAdded && !isInboundTerminated(previousState)) { + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (wasThrowableAdded + && isFirstFrameSent(previousState) + && !isInboundTerminated(previousState)) { Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); if (inboundError != TERMINATED) { // FIXME: must be scheduled on the connection event-loop to achieve serial @@ -705,6 +807,11 @@ public void onError(Throwable t) { final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, t); + } } @Override @@ -722,12 +829,20 @@ public void onComplete() { final int streamId = this.streamId; - if (isInboundTerminated(previousState)) { + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, null); + } + } } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java index 1706ece32..f3c52f648 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -30,6 +30,7 @@ import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -52,6 +53,8 @@ final class RequestResponseRequesterMono extends Mono final DuplexConnection connection; final PayloadDecoder payloadDecoder; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestResponseRequesterMono.class, "state"); @@ -72,6 +75,7 @@ final class RequestResponseRequesterMono extends Mono this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -79,8 +83,14 @@ public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestResponseMono allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestResponseMono allows only a single " + "Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); return; } @@ -88,15 +98,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - Operators.error( - actual, + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, p.metadata()); + } + p.release(); + + Operators.error(actual, e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + Operators.error(actual, e); return; } @@ -133,14 +156,25 @@ void sendFirstPayload(Payload payload, long initialRequestN) { this.done = true; final long previousState = markTerminated(STATE, this); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + payload.release(); if (!isTerminated(previousState)) { - this.actual.onError(Exceptions.unwrap(t)); + this.actual.onError(ut); } return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + try { sendReleasingPayload( streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, connection, allocator, true); @@ -150,6 +184,10 @@ void sendFirstPayload(Payload payload, long initialRequestN) { sm.remove(streamId, this); + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, e); + } + this.actual.onError(e); return; } @@ -164,6 +202,10 @@ void sendFirstPayload(Payload payload, long initialRequestN) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } } } @@ -181,6 +223,11 @@ public final void cancel() { ReassemblyUtils.synchronizedRelease(this, previousState); this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } } else if (!hasRequested(previousState)) { this.payload.release(); } @@ -201,10 +248,15 @@ public final void handlePayload(Payload value) { return; } - final CoreSubscriber a = this.actual; + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); - this.requesterResponderSupport.remove(this.streamId, this); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, null); + } + final CoreSubscriber a = this.actual; a.onNext(value); a.onComplete(); } @@ -222,7 +274,13 @@ public final void handleComplete() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, null); + } this.actual.onComplete(); } @@ -244,7 +302,13 @@ public final void handleError(Throwable cause) { ReassemblyUtils.synchronizedRelease(this, previousState); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, cause); + } this.actual.onError(cause); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java index f36211c7d..648afff13 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -32,6 +32,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; @@ -55,9 +56,10 @@ final class RequestResponseResponderSubscriber final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; final DuplexConnection connection; - final RSocket handler; + @Nullable final RequestInterceptor requestInterceptor; + boolean done; CompositeByteBuf frames; @@ -79,7 +81,9 @@ public RequestResponseResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; + this.frames = ReassemblyUtils.addFollowingFrame( allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); @@ -94,6 +98,7 @@ public RequestResponseResponderSubscriber( this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.payloadDecoder = null; this.handler = null; @@ -137,6 +142,11 @@ public void onNext(@Nullable Payload p) { if (p == null) { final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, null); + } return; } @@ -147,13 +157,16 @@ public void onNext(@Nullable Payload p) { p.release(); - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, e); + } return; } } catch (IllegalReferenceCountException e) { @@ -165,13 +178,28 @@ public void onNext(@Nullable Payload p) { streamId, new CanceledException("Failed to validate payload. Cause" + e.getMessage())); connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, e); + } return; } try { sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, connection, allocator, false); - } catch (Throwable ignored) { + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, null); + } + } catch (Throwable t) { currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } } } @@ -197,6 +225,11 @@ public void onError(Throwable t) { final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } } @Override @@ -216,7 +249,8 @@ public void handleCancel() { // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); final CompositeByteBuf frames = this.frames; if (frames != null) { @@ -224,6 +258,10 @@ public void handleCancel() { frames.release(); } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } return; } @@ -231,9 +269,15 @@ public void handleCancel() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } } @Override @@ -263,6 +307,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } return; } @@ -289,6 +338,11 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java index a3107d4d6..3608eaf52 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -31,6 +31,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -53,6 +54,8 @@ final class RequestStreamRequesterFlux extends Flux final DuplexConnection connection; final PayloadDecoder payloadDecoder; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestStreamRequesterFlux.class, "state"); @@ -71,14 +74,21 @@ final class RequestStreamRequesterFlux extends Flux this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestStreamFlux allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestStreamFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); return; } @@ -86,15 +96,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - Operators.error( - actual, + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, p.metadata()); + } + p.release(); + + Operators.error(actual, e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + Operators.error(actual, e); return; } @@ -141,14 +164,25 @@ void sendFirstPayload(Payload payload, long initialRequestN) { this.done = true; final long previousState = markTerminated(STATE, this); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_STREAM, payload.metadata()); + } + payload.release(); if (!isTerminated(previousState)) { - this.inboundSubscriber.onError(Exceptions.unwrap(t)); + this.inboundSubscriber.onError(ut); } return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.metadata()); + } + try { sendReleasingPayload( streamId, @@ -159,13 +193,17 @@ void sendFirstPayload(Payload payload, long initialRequestN) { connection, allocator, false); - } catch (Throwable e) { + } catch (Throwable t) { this.done = true; lazyTerminate(STATE, this); sm.remove(streamId, this); - this.inboundSubscriber.onError(e); + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } + + this.inboundSubscriber.onError(t); return; } @@ -180,6 +218,9 @@ void sendFirstPayload(Payload payload, long initialRequestN) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); connection.sendFrame(streamId, cancelFrame); + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } return; } @@ -215,6 +256,11 @@ public final void cancel() { ReassemblyUtils.synchronizedRelease(this, previousState); this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } } else if (!hasRequested(previousState)) { // no need to send anything, since the first request has not happened this.payload.release(); @@ -246,6 +292,11 @@ public final void handleComplete() { this.requesterResponderSupport.remove(this.streamId, this); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, null); + } + this.inboundSubscriber.onComplete(); } @@ -268,6 +319,11 @@ public final void handleError(Throwable cause) { ReassemblyUtils.synchronizedRelease(this, previousState); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, cause); + } + this.inboundSubscriber.onError(cause); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java index 620638d9c..6b06bc119 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -32,6 +32,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; @@ -39,6 +40,7 @@ import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; import reactor.util.context.Context; final class RequestStreamResponderSubscriber @@ -56,6 +58,8 @@ final class RequestStreamResponderSubscriber final RequesterResponderSupport requesterResponderSupport; final DuplexConnection connection; + @Nullable final RequestInterceptor requestInterceptor; + final RSocket handler; volatile Subscription s; @@ -81,6 +85,7 @@ public RequestStreamResponderSubscriber( this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; this.frames = ReassemblyUtils.addFollowingFrame( @@ -97,6 +102,7 @@ public RequestStreamResponderSubscriber( this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.payloadDecoder = null; this.handler = null; @@ -123,47 +129,77 @@ public void onNext(Payload p) { final DuplexConnection sender = this.connection; final ByteBufAllocator allocator = this.allocator; - if (p == null) { - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sender.sendFrame(streamId, completeFrame); - return; - } - final int mtu = this.mtu; try { if (!isValid(mtu, this.maxFrameLength, p, false)) { p.release(); - this.handleCancel(); + if (!this.tryTerminateOnError()) { + return; + } - this.done = true; - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); sender.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, e); + } return; } } catch (IllegalReferenceCountException e) { - this.handleCancel(); - this.done = true; + if (!this.tryTerminateOnError()) { + return; + } + final ByteBuf errorFrame = ErrorFrameCodec.encode( allocator, streamId, new CanceledException("Failed to validate payload. Cause" + e.getMessage())); sender.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, e); + } return; } try { sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); } catch (Throwable t) { - this.handleCancel(); - this.done = true; + if (!this.tryTerminateOnError()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } + } + } + + boolean tryTerminateOnError() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return false; + } + + this.done = true; + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return false; } + + this.requesterResponderSupport.remove(this.streamId, this); + + currentSubscription.cancel(); + + return true; } @Override @@ -186,11 +222,15 @@ public void onError(Throwable t) { } final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } } @Override @@ -206,11 +246,15 @@ public void onComplete() { } final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); this.connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, null); + } } @Override @@ -230,7 +274,8 @@ public final void handleCancel() { // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); final CompositeByteBuf frames = this.frames; if (frames != null) { @@ -238,6 +283,10 @@ public final void handleCancel() { frames.release(); } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } return; } @@ -245,9 +294,15 @@ public final void handleCancel() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId); + } } @Override @@ -260,25 +315,31 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas try { ReassemblyUtils.addFollowingFrame( frames, followingFrame, hasFollows, this.maxInboundPayloadSize); - } catch (IllegalStateException t) { + } catch (IllegalStateException e) { // if subscription is null, it means that streams has not yet reassembled all the fragments // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); this.frames = null; frames.release(); - logger.debug("Reassembly has failed", t); - // sends error frame from the responder side to tell that something went wrong final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, - new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, e); + } + + logger.debug("Reassembly has failed", e); return; } @@ -292,19 +353,25 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas S.lazySet(this, Operators.cancelledSubscription()); this.done = true; - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); ReferenceCountUtil.safeRelease(frames); - logger.debug("Reassembly has failed", t); - // sends error frame from the responder side to tell that something went wrong final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, t); + } + + logger.debug("Reassembly has failed", t); return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java index e3f70cede..2272ceb5f 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -4,7 +4,10 @@ import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.function.Function; import reactor.util.annotation.Nullable; class RequesterResponderSupport { @@ -15,6 +18,7 @@ class RequesterResponderSupport { private final PayloadDecoder payloadDecoder; private final ByteBufAllocator allocator; private final DuplexConnection connection; + @Nullable private final RequestInterceptor requestInterceptor; @Nullable final StreamIdSupplier streamIdSupplier; final IntObjectMap activeStreams; @@ -25,7 +29,8 @@ public RequesterResponderSupport( int maxInboundPayloadSize, PayloadDecoder payloadDecoder, DuplexConnection connection, - @Nullable StreamIdSupplier streamIdSupplier) { + @Nullable StreamIdSupplier streamIdSupplier, + Function requestInterceptorFunction) { this.activeStreams = new IntObjectHashMap<>(); this.mtu = mtu; @@ -35,6 +40,7 @@ public RequesterResponderSupport( this.allocator = connection.alloc(); this.streamIdSupplier = streamIdSupplier; this.connection = connection; + this.requestInterceptor = requestInterceptorFunction.apply((RSocket) this); } public int getMtu() { @@ -61,6 +67,11 @@ public DuplexConnection getDuplexConnection() { return connection; } + @Nullable + public RequestInterceptor getRequestInterceptor() { + return requestInterceptor; + } + /** * Issues next {@code streamId} * diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java new file mode 100644 index 000000000..b4e1a1ba3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java @@ -0,0 +1,151 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import java.util.List; +import java.util.function.Function; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +class CompositeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor[] requestInterceptors; + + public CompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { + this.requestInterceptors = requestInterceptors; + } + + @Override + public void dispose() { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + requestInterceptor.dispose(); + } + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable cause) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onTerminate(streamId, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onCancel(int streamId) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onCancel(streamId); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Nullable + static RequestInterceptor create( + RSocket rSocket, List> interceptors) { + switch (interceptors.size()) { + case 0: + return null; + case 1: + return new SafeRequestInterceptor(interceptors.get(0).apply(rSocket)); + default: + return new CompositeRequestInterceptor( + interceptors.stream().map(f -> f.apply(rSocket)).toArray(RequestInterceptor[]::new)); + } + } + + static class SafeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor requestInterceptor; + + public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + + @Override + public void dispose() { + requestInterceptor.dispose(); + } + + @Override + public boolean isDisposed() { + return requestInterceptor.isDisposed(); + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable cause) { + try { + requestInterceptor.onTerminate(streamId, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onCancel(int streamId) { + try { + requestInterceptor.onCancel(streamId); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java index fc032847c..be0d8278f 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -18,6 +18,7 @@ import io.rsocket.DuplexConnection; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; +import reactor.util.annotation.Nullable; /** * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. @@ -25,6 +26,16 @@ */ public class InitializingInterceptorRegistry extends InterceptorRegistry { + @Nullable + public RequestInterceptor initRequesterRequestInterceptor(RSocket rSocketRequester) { + return CompositeRequestInterceptor.create(rSocketRequester, getRequesterRequestInterceptors()); + } + + @Nullable + public RequestInterceptor initResponderRequestInterceptor(RSocket rSocketResponder) { + return CompositeRequestInterceptor.create(rSocketResponder, getResponderRequestInterceptors()); + } + public DuplexConnection initConnection( DuplexConnectionInterceptor.Type type, DuplexConnection connection) { for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { @@ -34,7 +45,7 @@ public DuplexConnection initConnection( } public RSocket initRequester(RSocket rsocket) { - for (RSocketInterceptor interceptor : getRequesterInteceptors()) { + for (RSocketInterceptor interceptor : getRequesterInterceptors()) { rsocket = interceptor.apply(rsocket); } return rsocket; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java index 427fa15ae..0ccc4cb92 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -15,9 +15,11 @@ */ package io.rsocket.plugins; +import io.rsocket.RSocket; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; +import java.util.function.Function; /** * Provides support for registering interceptors at the following levels: @@ -30,16 +32,46 @@ * */ public class InterceptorRegistry { - private List requesterInteceptors = new ArrayList<>(); - private List responderInterceptors = new ArrayList<>(); + private List> requesterRequestInterceptors = + new ArrayList<>(); + private List> responderRequestInterceptors = + new ArrayList<>(); + private List requesterRSocketInterceptors = new ArrayList<>(); + private List responderRSocketInterceptors = new ArrayList<>(); private List socketAcceptorInterceptors = new ArrayList<>(); private List connectionInterceptors = new ArrayList<>(); + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequester( + Function interceptor) { + requesterRequestInterceptors.add(interceptor); + return this; + } + + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forResponder( + Function interceptor) { + responderRequestInterceptors.add(interceptor); + return this; + } + /** * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. */ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { - requesterInteceptors.add(interceptor); + requesterRSocketInterceptors.add(interceptor); return this; } @@ -48,7 +80,7 @@ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forRequester(Consumer> consumer) { - consumer.accept(requesterInteceptors); + consumer.accept(requesterRSocketInterceptors); return this; } @@ -57,7 +89,7 @@ public InterceptorRegistry forRequester(Consumer> consu * requests. */ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { - responderInterceptors.add(interceptor); + responderRSocketInterceptors.add(interceptor); return this; } @@ -66,7 +98,7 @@ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forResponder(Consumer> consumer) { - consumer.accept(responderInterceptors); + consumer.accept(responderRSocketInterceptors); return this; } @@ -102,12 +134,20 @@ public InterceptorRegistry forConnection(Consumer getRequesterInteceptors() { - return requesterInteceptors; + List> getRequesterRequestInterceptors() { + return requesterRequestInterceptors; + } + + List> getResponderRequestInterceptors() { + return responderRequestInterceptors; + } + + List getRequesterInterceptors() { + return requesterRSocketInterceptors; } List getResponderInterceptors() { - return responderInterceptors; + return responderRSocketInterceptors; } List getConnectionInterceptors() { diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java new file mode 100644 index 000000000..5da850837 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -0,0 +1,73 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.Disposable; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * Class used to track the RSocket requests lifecycles. The main difference and advantage of this + * interceptor compares to {@link RSocketInterceptor} is that it allows intercepting the initial and + * terminal phases on every individual request. + * + *

Note, if any of the invocations will rise a runtime exception, this exception will be + * caught and be propagated to {@link reactor.core.publisher.Operators#onErrorDropped(Throwable, + * Context)} + * + * @since 1.1 + */ +public interface RequestInterceptor extends Disposable { + + /** + * Method which is being invoked on successful acceptance and start of a request. + * + * @param streamId used for the request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata); + + /** + * Method which is being invoked once a successfully accepted request is terminated. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onCancel(int)}. + * + * @param streamId used by this request + * @param t with which this finished has terminated. Must be one of the following signals + */ + void onTerminate(int streamId, @Nullable Throwable t); + + /** + * Method which is being invoked once a successfully accepted request is cancelled. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onTerminate(int, Throwable)}. + * + * @param streamId used by this request + */ + void onCancel(int streamId); + + /** + * Method which is being invoked on the request rejection. This method is being called only if the + * actual request can not be started and is called instead of the {@link #onStart(int, FrameType, + * ByteBuf)} method. The reason for rejection can be one of the following: + * + *

+ * + *

    + *
  • No available {@link io.rsocket.lease.Lease} on the requester or the responder sides + *
  • Invalid {@link io.rsocket.Payload} size or format on the Requester side, so the request + * is being rejected before the actual streamId is generated + *
  • A second subscription on the ongoing Request + *
+ * + * @param rejectionReason exception which causes rejection of a particular request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onReject(Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata); +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java index d080b166d..b77e51537 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -543,6 +543,7 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, + __ -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java index 0857a2de8..f5422a4bf 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java @@ -14,6 +14,7 @@ import io.rsocket.Payload; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.time.Duration; @@ -46,7 +47,9 @@ public static void setUp() { @ParameterizedTest @MethodSource("frameSent") public void frameShouldBeSentOnSubscription(Consumer monoConsumer) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Payload payload = genericPayload(activeStreams.getAllocator()); final FireAndForgetRequesterMono fireAndForgetRequesterMono = new FireAndForgetRequesterMono(payload, activeStreams); @@ -62,7 +65,6 @@ public void frameShouldBeSentOnSubscription(Consumer // should not add anything to map stateAssert.isTerminated(); activeStreams.assertNoActiveStreams(); - final ByteBuf frame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() @@ -79,6 +81,10 @@ public void frameShouldBeSentOnSubscription(Consumer Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectNothing(); } /** @@ -189,7 +195,9 @@ static Stream> frameSent() { @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") public void shouldErrorOnIncorrectRefCntInGivenPayload( Consumer monoConsumer) { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); final Payload payload = ByteBufPayload.create(""); @@ -210,6 +218,9 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject(FrameType.REQUEST_FNF, new IllegalReferenceCountException("refCnt: 0")) + .expectNothing(); } static Stream> @@ -233,7 +244,9 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( Consumer monoConsumer) { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); @@ -260,6 +273,12 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( streamManager.assertNoActiveStreams(); Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject( + FrameType.REQUEST_FNF, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK))) + .expectNothing(); } static Stream> @@ -289,8 +308,10 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( @ParameterizedTest @MethodSource("shouldErrorIfNoAvailabilitySource") public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RuntimeException exception = new RuntimeException("test"); final TestRequesterResponderSupport streamManager = - TestRequesterResponderSupport.client(new RuntimeException("test")); + TestRequesterResponderSupport.client(exception, testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); final Payload payload = genericPayload(allocator); @@ -311,6 +332,7 @@ public void shouldErrorIfNoAvailability(Consumer mon streamManager.assertNoActiveStreams(); Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor.expectOnReject(FrameType.REQUEST_FNF, exception).expectNothing(); } static Stream> shouldErrorIfNoAvailabilitySource() { @@ -333,7 +355,9 @@ static Stream> shouldErrorIfNoAvailabilityS /** Ensures single subscription happens in case of racing */ @Test public void shouldSubscribeExactlyOnce1() { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); final TestDuplexConnection sender = streamManager.getDuplexConnection(); @@ -349,7 +373,7 @@ public void shouldSubscribeExactlyOnce1() { () -> RaceTestUtils.race( () -> { - AtomicReference atomicReference = new AtomicReference(); + AtomicReference atomicReference = new AtomicReference<>(); fireAndForgetRequesterMono.subscribe(null, atomicReference::set); Throwable throwable = atomicReference.get(); if (throwable != null) { @@ -380,6 +404,27 @@ public void shouldSubscribeExactlyOnce1() { stateAssert.isTerminated(); streamManager.assertNoActiveStreams(); + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); } Assertions.assertThat(sender.isEmpty()).isTrue(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index ae1282c1e..a36415cb1 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -33,10 +33,15 @@ import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.LeaseFrameCodec; import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; @@ -64,9 +69,12 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.EmitterProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; import reactor.test.StepVerifier; class RSocketLeaseTest { @@ -107,6 +115,7 @@ void setUp() { 0, 0, null, + __ -> null, requesterLeaseHandler); mockRSocketHandler = mock(RSocket.class); @@ -144,7 +153,25 @@ void setUp() { Publisher payloadPublisher = a.getArgument(0); return Flux.from(payloadPublisher) .doOnNext(ReferenceCounted::release) - .thenMany(Flux.empty()); + .transform( + Operators.lift( + (__, actual) -> + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + actual.onSubscribe(this); + } + + @Override + protected void hookOnComplete() { + actual.onComplete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + actual.onError(throwable); + } + })); }); rSocketResponder = @@ -155,7 +182,8 @@ void setUp() { responderLeaseHandler, 0, FRAME_LENGTH_MASK, - Integer.MAX_VALUE); + Integer.MAX_VALUE, + __ -> null); } @Test @@ -355,32 +383,86 @@ void requesterAvailabilityRespectsTransport() { } @ParameterizedTest - @MethodSource("interactions") - void responderMissingLeaseRequestsAreRejected( - BiFunction> interaction) { + @MethodSource("responderInteractions") + void responderMissingLeaseRequestsAreRejected(FrameType frameType) { ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - StepVerifier.create(interaction.apply(rSocketResponder, payload1)) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderPresentLeaseRequestsAreAccepted( - BiFunction> interaction, FrameType frameType) { + @MethodSource("responderInteractions") + void responderPresentLeaseRequestsAreAccepted(FrameType frameType) { leaseSender.onNext(Lease.create(5_000, 2)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux.from(interaction.apply(rSocketResponder, payload1)) - .as(StepVerifier::create) - .expectComplete() - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFireAndForget(1, fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } switch (frameType) { case REQUEST_FNF: @@ -398,41 +480,113 @@ void responderPresentLeaseRequestsAreAccepted( } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderDepletedAllowedLeaseRequestsAreRejected( - BiFunction> interaction) { + @MethodSource("responderInteractions") + void responderDepletedAllowedLeaseRequestsAreRejected(FrameType frameType) { leaseSender.onNext(Lease.create(5_000, 1)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux responder = Flux.from(interaction.apply(rSocketResponder, payload1)); - responder.subscribe(); + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test2", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf fnfFrame2 = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(fnfFrame); + rSocketResponder.handleFrame(fnfFrame2); + fnfFrame.release(); + fnfFrame2.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf requestResponseFrame2 = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(requestResponseFrame); + rSocketResponder.handleFrame(requestResponseFrame2); + requestResponseFrame.release(); + requestResponseFrame2.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + final ByteBuf requestStreamFrame2 = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, 1, payload2); + rSocketResponder.handleFrame(requestStreamFrame); + rSocketResponder.handleFrame(requestStreamFrame2); + requestStreamFrame.release(); + requestStreamFrame2.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + final ByteBuf requestChannelFrame2 = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, true, 1, payload2); + rSocketResponder.handleFrame(requestChannelFrame); + rSocketResponder.handleFrame(requestChannelFrame2); + requestChannelFrame.release(); + requestChannelFrame2.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); - ByteBuf buffer2 = byteBufAllocator.buffer(); - buffer2.writeCharSequence("test", CharsetUtil.UTF_8); - Payload payload2 = ByteBufPayload.create(buffer2); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); - Flux.from(interaction.apply(rSocketResponder, payload2)) - .as(StepVerifier::create) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(2) + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest @@ -528,4 +682,12 @@ static Stream interactions() { (rSocket, payload) -> rSocket.requestChannel(Mono.just(payload)), FrameType.REQUEST_CHANNEL)); } + + static Stream responderInteractions() { + return Stream.of( + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index fda6b61ee..4c7921db1 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -77,6 +77,7 @@ void setUp() { 0, 0, null, + __ -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index a0b3ef3f2..9aa8442d9 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -1375,6 +1375,7 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, + (__) -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 0d0b0f093..d796d45e5 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -63,6 +63,8 @@ import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.ByteBufPayload; @@ -269,15 +271,18 @@ protected void hookOnSubscribe(Subscription subscription) { @Test public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); + final MonoProcessor monoProcessor = MonoProcessor.create(); rule.setAcceptingSocket( new RSocket() { @Override public Flux requestChannel(Publisher payloads) { payloads.subscribe(assertSubscriber); - return Flux.never(); + return monoProcessor.flux(); } }, Integer.MAX_VALUE); @@ -303,19 +308,23 @@ public Flux requestChannel(Publisher payloads) { ByteBuf data3 = allocator.buffer(); data3.writeCharSequence("def3", CharsetUtil.UTF_8); ByteBuf nextFrame3 = - PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + PayloadFrameCodec.encode(allocator, 1, false, true, true, metadata3, data3); RaceTestUtils.race( () -> { rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); }, - assertSubscriber::cancel); + () -> { + assertSubscriber.cancel(); + monoProcessor.onComplete(); + }); Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnComplete(1).expectNothing(); } } @@ -323,6 +332,8 @@ public Flux requestChannel(Publisher payloads) { public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); @@ -350,11 +361,13 @@ public Flux requestChannel(Publisher payloads) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); }); Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); } } @@ -363,6 +376,8 @@ public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChann Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); @@ -395,11 +410,12 @@ public Flux requestChannel(Publisher payloads) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); }, parallel); Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); - + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); rule.assertHasNoLeaks(); } } @@ -410,6 +426,8 @@ public Flux requestChannel(Publisher payloads) { Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { FluxSink[] sinks = new FluxSink[1]; AssertSubscriber assertSubscriber = AssertSubscriber.create(); @@ -499,6 +517,7 @@ public Flux requestChannel(Publisher payloads) { return msg.refCnt() == 0; }); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnError(1).expectNothing(); } } @@ -507,6 +526,8 @@ public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStrea Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { FluxSink[] sinks = new FluxSink[1]; @@ -536,6 +557,8 @@ public Flux requestStream(Payload payload) { Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor.expectOnStart(1, REQUEST_STREAM).expectOnCancel(1).expectNothing(); } } @@ -544,6 +567,8 @@ public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestRespo Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); for (int i = 0; i < 10000; i++) { Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; @@ -576,6 +601,16 @@ public void subscribe(CoreSubscriber actual) { Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_CANCEL)) + .expectNothing(); } } @@ -795,7 +830,8 @@ public Flux requestChannel(Publisher payloads) { + ERROR + "} but was {" + frameType(rule.connection.getSent().iterator().next()) - + "}"); + + "}") + .matches(ByteBuf::release); } private static Stream refCntCases() { @@ -1178,6 +1214,7 @@ public static class ServerSocketRule extends AbstractSocketRule requestInterceptor); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index 38745327e..785532bcf 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -569,7 +569,8 @@ public Flux requestChannel(Publisher payloads) { ResponderLeaseHandler.None, 0, FRAME_LENGTH_MASK, - Integer.MAX_VALUE); + Integer.MAX_VALUE, + __ -> null); crs = new RSocketRequester( @@ -582,6 +583,7 @@ public Flux requestChannel(Publisher payloads) { 0, 0, null, + __ -> null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java index 4b4311a00..54033e249 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java @@ -72,7 +72,7 @@ public static void setUp() { @ValueSource(strings = {"inbound", "outbound", "inboundCancel"}) public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); final TestPublisher publisher = TestPublisher.create(); @@ -373,6 +373,56 @@ public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { } } + @Test + public void streamShouldWorkCorrectlyWhenRacingOutboundErrorWithSubscription() { + RuntimeException exception = new RuntimeException("test"); + + for (int i = 0; i < 10000; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> publisher.error(exception)); + + stateAssert.isTerminated(); + + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(ERROR) + .hasData("test") + .hasStreamId(1) + .hasNoLeaks(); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + + allocator.assertHasNoLeaks(); + } + } + @Test public void streamShouldWorkCorrectlyWhenRacingHandleCancelWithSubscription() { for (int i = 0; i < 10000; i++) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java index 520dd0196..1396774c4 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java @@ -30,6 +30,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.ArrayList; @@ -39,11 +40,10 @@ import java.util.stream.Stream; import org.assertj.core.api.Assertions; import org.assertj.core.api.Assumptions; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; @@ -169,8 +169,9 @@ public String toString() { @MethodSource("scenarios") public void shouldSubscribeExactlyOnce(Scenario scenario) { for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); final TestRequesterResponderSupport requesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload( @@ -214,6 +215,9 @@ public void shouldSubscribeExactlyOnce(Scenario scenario) { if (requestOperator instanceof FrameHandler) { ((FrameHandler) requestOperator).handleComplete(); + if (scenario.requestType() == REQUEST_CHANNEL) { + ((FrameHandler) requestOperator).handleCancel(); + } } }) .thenCancel() @@ -240,6 +244,29 @@ public void shouldSubscribeExactlyOnce(Scenario scenario) { stepVerifier.verify(Duration.ofSeconds(1)); requesterResponderSupport.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } } } @@ -251,7 +278,9 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); @@ -316,6 +345,12 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { activeStreams.assertNoActiveStreams(); Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } } } @@ -330,7 +365,9 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); @@ -404,6 +441,16 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); } Assertions.assertThat(responsePayload.release()).isTrue(); @@ -419,22 +466,24 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { * Ensures that in case of racing between next element and cancel we will not have any memory * leaks */ - @Test - public void shouldHaveNoLeaksOnNextAndCancelRacing() { + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnNextAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - - StepVerifier.create(requestResponseRequesterMono.doOnNext(Payload::release)) - .expectSubscription() - .expectComplete() - .verifyLater(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + requestOperator.subscribe((AssertSubscriber) assertSubscriber); final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) @@ -446,16 +495,16 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) + .typeOf(scenario.requestType()) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); RaceTestUtils.race( - requestResponseRequesterMono::cancel, - () -> requestResponseRequesterMono.handlePayload(response)); + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handlePayload(response)); - Assertions.assertThat(payload.refCnt()).isZero(); + assertSubscriber.values().forEach(Payload::release); Assertions.assertThat(response.refCnt()).isZero(); activeStreams.assertNoActiveStreams(); @@ -468,10 +517,19 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + assertSubscriber.assertTerminated(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); } Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); - - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); activeStreams.getAllocator().assertHasNoLeaks(); } } @@ -482,84 +540,106 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { * cancel we will not have any memory leaks */ @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(boolean withReassembly) { + @MethodSource("scenarios") + public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + boolean[] withReassemblyOptions = new boolean[] {true, false}; final ArrayList droppedErrors = new ArrayList<>(); Hooks.onErrorDropped(droppedErrors::add); - try { - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); - - final StateAssert stateAssert = - StateAssert.assertThat(requestResponseRequesterMono); - - stateAssert.isUnsubscribed(); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(AssertSubscriber.create(0)); - stateAssert.hasSubscribedFlagOnly(); - - assertSubscriber.request(1); - - stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + try { + for (boolean withReassembly : withReassemblyOptions) { + for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + final StateAssert stateAssert; + if (requestOperator instanceof RequestResponseRequesterMono) { + stateAssert = StateAssert.assertThat((RequestResponseRequesterMono) requestOperator); + } else if (requestOperator instanceof RequestStreamRequesterFlux) { + stateAssert = StateAssert.assertThat((RequestStreamRequesterFlux) requestOperator); + } else { + stateAssert = StateAssert.assertThat((RequestChannelRequesterFlux) requestOperator); + } - final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); - FrameAssert.assertThat(sentFrame) - .isNotNull() - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) - .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) - .hasData(TestRequesterResponderSupport.DATA_CONTENT) - .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) - .hasClientSideStreamId() - .hasStreamId(1) - .hasNoLeaks(); + stateAssert.isUnsubscribed(); + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); - if (withReassembly) { - final ByteBuf fragmentBuf = - activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); - requestResponseRequesterMono.handleNext(fragmentBuf, true, false); - // mimic frameHandler behaviour - fragmentBuf.release(); - } + requestOperator.subscribe((AssertSubscriber) assertSubscriber); - final RuntimeException testException = new RuntimeException("test"); - RaceTestUtils.race( - requestResponseRequesterMono::cancel, - () -> requestResponseRequesterMono.handleError(testException)); + stateAssert.hasSubscribedFlagOnly(); - Assertions.assertThat(payload.refCnt()).isZero(); + assertSubscriber.request(1); - activeStreams.assertNoActiveStreams(); - stateAssert.isTerminated(); + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); - final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); - if (!isEmpty) { - final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); - FrameAssert.assertThat(cancellationFrame) + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) .isNotNull() - .typeOf(FrameType.CANCEL) + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); - Assertions.assertThat(droppedErrors).containsExactly(testException); - } else { - assertSubscriber.assertTerminated().assertErrorMessage("test"); - } - Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + if (withReassembly) { + final ByteBuf fragmentBuf = + activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); + ((RequesterFrameHandler) requestOperator).handleNext(fragmentBuf, true, false); + // mimic frameHandler behaviour + fragmentBuf.release(); + } - stateAssert.isTerminated(); - droppedErrors.clear(); - activeStreams.getAllocator().assertHasNoLeaks(); + final RuntimeException testException = new RuntimeException("test"); + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handleError(testException)); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(droppedErrors).containsExactly(testException); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnError(1) + .expectNothing(); + + assertSubscriber.assertTerminated().assertErrorMessage("test"); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + + stateAssert.isTerminated(); + droppedErrors.clear(); + activeStreams.getAllocator().assertHasNoLeaks(); + } } } finally { Hooks.resetOnErrorDropped(); @@ -583,20 +663,25 @@ public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(boolean with * *

Ensures full serialization of outgoing signal (frames) */ - @Test - public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { + @ParameterizedTest + @MethodSource("scenarios") + public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); RaceTestUtils.race(() -> assertSubscriber.cancel(), () -> assertSubscriber.request(1)); @@ -604,11 +689,7 @@ public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) .isNotNull() - .typeOf(FrameType.REQUEST_RESPONSE) - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) + .typeOf(scenario.requestType()) .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasNoFragmentsFollow() @@ -623,15 +704,17 @@ public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); - } - Assertions.assertThat(payload.refCnt()).isZero(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + ((RequesterFrameHandler) requestOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); - requestResponseRequesterMono.handlePayload(response); Assertions.assertThat(response.refCnt()).isZero(); - activeStreams.assertNoActiveStreams(); Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); @@ -639,20 +722,26 @@ public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { } /** Ensures that CancelFrame is sent exactly once in case of racing between cancel() methods */ - @Test - public void shouldSentCancelFrameExactlyOnce() { + @ParameterizedTest + @MethodSource("scenarios") + public void shouldSentCancelFrameExactlyOnce(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requesterOperator = + scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requesterOperator.subscribe((AssertSubscriber) assertSubscriber); assertSubscriber.request(1); @@ -660,19 +749,15 @@ public void shouldSentCancelFrameExactlyOnce() { FrameAssert.assertThat(sentFrame) .isNotNull() .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) + .typeOf(scenario.requestType()) .hasClientSideStreamId() - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasStreamId(1) .hasNoLeaks(); RaceTestUtils.race( - requestResponseRequesterMono::cancel, requestResponseRequesterMono::cancel); + ((Subscription) requesterOperator)::cancel, ((Subscription) requesterOperator)::cancel); final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(cancelFrame) @@ -682,15 +767,18 @@ public void shouldSentCancelFrameExactlyOnce() { .hasStreamId(1) .hasNoLeaks(); - Assertions.assertThat(payload.refCnt()).isZero(); - activeStreams.assertNoActiveStreams(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + activeStreams.assertNoActiveStreams(); - requestResponseRequesterMono.handlePayload(response); + ((RequesterFrameHandler) requesterOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); Assertions.assertThat(response.refCnt()).isZero(); - requestResponseRequesterMono.handleComplete(); + ((RequesterFrameHandler) requesterOperator).handleComplete(); assertSubscriber.assertNotTerminated(); activeStreams.assertNoActiveStreams(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java index 270bc4a05..4f7821e4a 100755 --- a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java @@ -20,6 +20,7 @@ import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; import static io.rsocket.frame.FrameType.REQUEST_FNF; import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; import io.netty.buffer.ByteBuf; import io.rsocket.FrameAssert; @@ -29,6 +30,8 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import java.util.ArrayList; import java.util.concurrent.ThreadLocalRandom; @@ -86,6 +89,12 @@ public ResponderFrameHandler responseOperator( new RequestResponseResponderSubscriber( streamId, firstFragment, streamManager, handler); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + return subscriber; } @@ -99,6 +108,12 @@ public ResponderFrameHandler responseOperator( RequestResponseResponderSubscriber subscriber = new RequestResponseResponderSubscriber(streamId, streamManager); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + return handler.requestResponse(firstPayload).subscribeWith(subscriber); } @@ -128,6 +143,12 @@ public ResponderFrameHandler responseOperator( RequestStreamResponderSubscriber subscriber = new RequestStreamResponderSubscriber( streamId, initialRequestN, firstFragment, streamManager, handler); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + streamManager.activeStreams.put(streamId, subscriber); return subscriber; } @@ -142,6 +163,12 @@ public ResponderFrameHandler responseOperator( RequestStreamResponderSubscriber subscriber = new RequestStreamResponderSubscriber(streamId, initialRequestN, streamManager); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + return handler.requestStream(firstPayload).subscribeWith(subscriber); } @@ -172,6 +199,12 @@ public ResponderFrameHandler responseOperator( new RequestChannelResponderSubscriber( streamId, initialRequestN, firstFragment, streamManager, handler); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + return subscriber; } @@ -186,6 +219,12 @@ public ResponderFrameHandler responseOperator( new RequestChannelResponderSubscriber( streamId, initialRequestN, firstPayload, streamManager); streamManager.activeStreams.put(streamId, responderSubscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + return handler.requestChannel(responderSubscriber).subscribeWith(responderSubscriber); } @@ -242,8 +281,9 @@ public Flux requestChannel(Publisher payloads) { void shouldHandleRequest(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); TestPublisher testPublisher = TestPublisher.create(); @@ -286,6 +326,9 @@ void shouldHandleRequest(Scenario scenario) { .hasStreamId(1) .hasRequestN(1) .hasNoLeaks(); + + responderFrameHandler.handleComplete(); + testHandler.consumer.assertComplete(); } } @@ -294,6 +337,10 @@ void shouldHandleRequest(Scenario scenario) { .assertValueCount(1) .assertValuesWith(p -> PayloadAssert.assertThat(p).hasNoLeaks()); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); allocator.assertHasNoLeaks(); } @@ -302,8 +349,9 @@ void shouldHandleRequest(Scenario scenario) { void shouldHandleFragmentedRequest(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); TestPublisher testPublisher = TestPublisher.create(); @@ -370,6 +418,11 @@ void shouldHandleFragmentedRequest(Scenario scenario) { firstPayload.release(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + allocator.assertHasNoLeaks(); } @@ -378,8 +431,9 @@ void shouldHandleFragmentedRequest(Scenario scenario) { void shouldHandleInterruptedFragmentation(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); TestPublisher testPublisher = TestPublisher.create(); TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); @@ -413,6 +467,11 @@ void shouldHandleInterruptedFragmentation(Scenario scenario) { testPublisher.assertWasNotSubscribed(); testRequesterResponderSupport.assertNoActiveStreams(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index b96139fb5..fe3f75e1b 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -61,6 +61,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { 0, 0, null, + __ -> null, RequesterLeaseHandler.None); String errorMsg = "error"; @@ -98,6 +99,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { 0, 0, null, + __ -> null, RequesterLeaseHandler.None); conn.addToReceivedBuffer( diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java index f81e8a610..e282d72d5 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -23,9 +23,11 @@ import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.util.ArrayList; @@ -34,7 +36,7 @@ import reactor.core.Exceptions; import reactor.util.annotation.Nullable; -final class TestRequesterResponderSupport extends RequesterResponderSupport { +final class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { static final String DATA_CONTENT = "testData"; static final String METADATA_CONTENT = "testMetadata"; @@ -47,14 +49,16 @@ final class TestRequesterResponderSupport extends RequesterResponderSupport { DuplexConnection connection, int mtu, int maxFrameLength, - int maxInboundPayloadSize) { + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { super( mtu, maxFrameLength, maxInboundPayloadSize, PayloadDecoder.ZERO_COPY, connection, - streamIdSupplier); + streamIdSupplier, + (__) -> requestInterceptor); this.error = error; } @@ -170,6 +174,18 @@ public synchronized int addAndGetNextStreamId(FrameHandler frameHandler) { return nextStreamId; } + public static TestRequesterResponderSupport client( + @Nullable Throwable e, @Nullable RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor, + e); + } + public static TestRequesterResponderSupport client(@Nullable Throwable e) { return client(0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, e); } @@ -182,14 +198,34 @@ public static TestRequesterResponderSupport client( mtu, maxFrameLength, maxInboundPayloadSize, + null, e); } + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize) { + return client(duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, null); + } + public static TestRequesterResponderSupport client( TestDuplexConnection duplexConnection, int mtu, int maxFrameLength, int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + return client( + duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, requestInterceptor, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor, @Nullable Throwable e) { return new TestRequesterResponderSupport( e, @@ -197,7 +233,8 @@ public static TestRequesterResponderSupport client( duplexConnection, mtu, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + requestInterceptor); } public static TestRequesterResponderSupport client( @@ -217,6 +254,16 @@ public static TestRequesterResponderSupport client() { return client(0); } + public static TestRequesterResponderSupport client(RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor); + } + public TestRequesterResponderSupport assertNoActiveStreams() { Assertions.assertThat(activeStreams).isEmpty(); return this; diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java new file mode 100644 index 000000000..24a035b78 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java @@ -0,0 +1,754 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.FrameType; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.annotation.Nullable; + +public class RequestInterceptorTest { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientRequesterSide(boolean errorOutcome) { + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientResponderSide(boolean errorOutcome) + throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forResponder( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerRequesterSide(boolean errorOutcome) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forResponder( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create().connect(LocalClientTransport.create("test")).block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerResponderSide(boolean errorOutcome) + throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @Test + void ensuresExceptionInTheInterceptorIsHandledProperly() { + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId) { + throw new RuntimeException("testOnCancel"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + StepVerifier.create(rSocket.fireAndForget(DefaultPayload.create("test"))) + .expectSubscription() + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestStream(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestChannel(Flux.just(DefaultPayload.create("test")))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldSupportMultipleInterceptors(boolean errorOutcome) { + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor1 = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequestInterceptor testRequestInterceptor2 = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequester( + (Function) + (__) -> testRequestInterceptor) + .forRequester( + (Function) + (__) -> testRequestInterceptor1) + .forRequester( + (Function) + (__) -> testRequestInterceptor2)) + .connect(LocalClientTransport.create("test")) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + testRequestInterceptor2 + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java new file mode 100644 index 000000000..fe9de7ce1 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -0,0 +1,141 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Queue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Condition; +import reactor.util.annotation.Nullable; + +public class TestRequestInterceptor implements RequestInterceptor { + + final Queue events = new MpscUnboundedArrayQueue<>(128); + + @Override + public void dispose() {} + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_START, streamId, requestType, null)); + } + + @Override + public void onTerminate(int streamId, @Nullable Throwable t) { + events.add( + new Event(t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, null, t)); + } + + @Override + public void onCancel(int streamId) { + events.add(new Event(EventType.ON_CANCEL, streamId, null, null)); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_REJECT, -1, requestType, rejectionReason)); + } + + public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_START) + .hasFieldOrPropertyWithValue("streamId", streamId) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectOnComplete(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_COMPLETE) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnError(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_ERROR) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnCancel(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_CANCEL) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor assertNext(Consumer consumer) { + final Event event = events.poll(); + Assertions.assertThat(event).isNotNull(); + + consumer.accept(event); + + return this; + } + + public TestRequestInterceptor expectOnReject(FrameType requestType, Throwable rejectionReason) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_REJECT) + .has( + new Condition<>( + e -> { + Assertions.assertThat(e.error) + .isExactlyInstanceOf(rejectionReason.getClass()) + .hasMessage(rejectionReason.getMessage()) + .hasCause(rejectionReason.getCause()); + return true; + }, + "Has rejection reason which matches to %s", + rejectionReason)) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectNothing() { + final Event event = events.poll(); + + Assertions.assertThat(event).isNull(); + + return this; + } + + public static final class Event { + public final EventType eventType; + public final int streamId; + public final FrameType requestType; + public final Throwable error; + + Event(EventType eventType, int streamId, FrameType requestType, Throwable error) { + this.eventType = eventType; + this.streamId = streamId; + this.requestType = requestType; + this.error = error; + } + } + + public enum EventType { + ON_START, + ON_COMPLETE, + ON_ERROR, + ON_CANCEL, + ON_REJECT + } +}