From 909cb4f5ed6a9498d0b0d1737f9ecef3888ba176 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 22 Oct 2020 16:41:27 +0300 Subject: [PATCH 1/2] improves RequestInterceptor API to have FrameType in all the calls Signed-off-by: Oleh Dokuka --- .../core/FireAndForgetRequesterMono.java | 10 +++--- .../FireAndForgetResponderSubscriber.java | 11 +++--- .../core/RequestChannelRequesterFlux.java | 20 +++++------ .../RequestChannelResponderSubscriber.java | 36 +++++++++---------- .../core/RequestResponseRequesterMono.java | 12 +++---- .../RequestResponseResponderSubscriber.java | 20 +++++------ .../core/RequestStreamRequesterFlux.java | 16 +++++---- .../RequestStreamResponderSubscriber.java | 18 +++++----- .../plugins/CompositeRequestInterceptor.java | 16 ++++----- .../rsocket/plugins/RequestInterceptor.java | 14 +++++--- .../plugins/RequestInterceptorTest.java | 10 +++--- .../plugins/TestRequestInterceptor.java | 9 ++--- 12 files changed, 102 insertions(+), 90 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java index dec946bab..eceb0976c 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -141,7 +141,7 @@ public void subscribe(CoreSubscriber actual) { p.release(); if (interceptor != null) { - interceptor.onCancel(streamId); + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); } return; @@ -153,7 +153,7 @@ public void subscribe(CoreSubscriber actual) { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onTerminate(streamId, e); + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); } actual.onError(e); @@ -163,7 +163,7 @@ public void subscribe(CoreSubscriber actual) { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onTerminate(streamId, null); + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); } actual.onComplete(); @@ -262,7 +262,7 @@ public Void block() { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onTerminate(streamId, e); + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); } throw Exceptions.propagate(e); @@ -271,7 +271,7 @@ public Void block() { lazyTerminate(STATE, this); if (interceptor != null) { - interceptor.onTerminate(streamId, null); + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); } return null; diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java index 889c98fde..e76fdf9ed 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -21,6 +21,7 @@ import io.netty.util.ReferenceCountUtil; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.plugins.RequestInterceptor; import org.reactivestreams.Subscription; @@ -101,7 +102,7 @@ public void onNext(Void voidVal) {} public void onError(Throwable t) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(this.streamId, t); + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); } logger.debug("Dropped Outbound error", t); @@ -111,7 +112,7 @@ public void onError(Throwable t) { public void onComplete() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(this.streamId, null); + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, null); } } @@ -131,7 +132,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_FNF, t); } logger.debug("Reassembly has failed", t); @@ -151,7 +152,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(this.streamId, t); + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); } logger.debug("Reassembly has failed", t); @@ -175,7 +176,7 @@ public final void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_FNF); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java index 8a57820c5..9b2936444 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -260,7 +260,7 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { this.inboundDone = true; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); } this.inboundSubscriber.onError(t); @@ -281,7 +281,7 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { connection.sendFrame(streamId, cancelFrame); if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_CHANNEL); } return; } @@ -364,7 +364,7 @@ void propagateErrorSafely(Throwable t) { if (!this.inboundDone) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, t); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); } this.inboundDone = true; @@ -386,7 +386,7 @@ public final void cancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(this.streamId); + requestInterceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); } } @@ -449,7 +449,7 @@ public void onError(Throwable t) { synchronized (this) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, t); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); } this.inboundDone = true; @@ -492,7 +492,7 @@ public void onComplete() { if (isInboundTerminated) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, null); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); } } } @@ -515,7 +515,7 @@ public final void handleComplete() { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, null); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); } } @@ -538,7 +538,7 @@ public final void handleError(Throwable cause) { } else if (isInboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, cause); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause); } Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); @@ -555,7 +555,7 @@ public final void handleError(Throwable cause) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, cause); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); } this.inboundSubscriber.onError(cause); @@ -599,7 +599,7 @@ public void handleCancel() { if (inboundTerminated) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, null); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java index 9d4cd5f1e..c52fdca25 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -310,7 +310,7 @@ public void cancel() { if (isOutboundTerminated) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, null); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); } } } @@ -337,7 +337,7 @@ public final void handleCancel() { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onCancel(this.streamId); + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); } return; } @@ -349,7 +349,7 @@ public final void handleCancel() { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onCancel(this.streamId); + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); } } @@ -464,7 +464,7 @@ public final void handleError(Throwable t) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, t); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); } } @@ -490,7 +490,7 @@ public void handleComplete() { if (isOutboundTerminated) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, null); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); } } } @@ -514,7 +514,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, t); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); } Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); @@ -530,7 +530,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, t); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); } return; } @@ -572,7 +572,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, e); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, e); } Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); @@ -591,7 +591,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, e); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); } return; @@ -620,7 +620,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(this.streamId, t); + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); } Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); @@ -638,7 +638,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, t); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); } return; @@ -690,7 +690,7 @@ public void onNext(Payload p) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, e); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); } Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); @@ -705,7 +705,7 @@ public void onNext(Payload p) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, e); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); } return; } @@ -720,7 +720,7 @@ public void onNext(Payload p) { } else if (isOutboundTerminated(previousState)) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, e); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); } Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); @@ -736,7 +736,7 @@ public void onNext(Payload p) { final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, e); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); } return; } @@ -749,7 +749,7 @@ public void onNext(Payload p) { long previousState = this.tryTerminate(false); final RequestInterceptor interceptor = requestInterceptor; if (interceptor != null && !isTerminated(previousState)) { - interceptor.onTerminate(streamId, t); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); } } } @@ -810,7 +810,7 @@ && isFirstFrameSent(previousState) final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, t); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); } } @@ -840,7 +840,7 @@ public void onComplete() { if (isInboundTerminated) { final RequestInterceptor interceptor = this.requestInterceptor; if (interceptor != null) { - interceptor.onTerminate(streamId, null); + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java index f3c52f648..850298a2a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -185,7 +185,7 @@ void sendFirstPayload(Payload payload, long initialRequestN) { sm.remove(streamId, this); if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, e); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); } this.actual.onError(e); @@ -204,7 +204,7 @@ void sendFirstPayload(Payload payload, long initialRequestN) { connection.sendFrame(streamId, cancelFrame); if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); } } } @@ -226,7 +226,7 @@ public final void cancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); } } else if (!hasRequested(previousState)) { this.payload.release(); @@ -253,7 +253,7 @@ public final void handlePayload(Payload value) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, null); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); } final CoreSubscriber a = this.actual; @@ -279,7 +279,7 @@ public final void handleComplete() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, null); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); } this.actual.onComplete(); @@ -307,7 +307,7 @@ public final void handleError(Throwable cause) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, cause); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, cause); } this.actual.onError(cause); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java index 648afff13..3d9d020ff 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -145,7 +145,7 @@ public void onNext(@Nullable Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, null); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); } return; } @@ -165,7 +165,7 @@ public void onNext(@Nullable Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, e); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); } return; } @@ -181,7 +181,7 @@ public void onNext(@Nullable Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, e); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); } return; } @@ -191,14 +191,14 @@ public void onNext(@Nullable Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, null); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); } } catch (Throwable t) { currentSubscription.cancel(); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); } } } @@ -228,7 +228,7 @@ public void onError(Throwable t) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); } } @@ -260,7 +260,7 @@ public void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); } return; } @@ -276,7 +276,7 @@ public void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); } } @@ -310,7 +310,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); } return; } @@ -341,7 +341,7 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); } return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java index 3608eaf52..47e8c1610 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -200,7 +200,7 @@ void sendFirstPayload(Payload payload, long initialRequestN) { sm.remove(streamId, this); if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); } this.inboundSubscriber.onError(t); @@ -219,7 +219,7 @@ void sendFirstPayload(Payload payload, long initialRequestN) { connection.sendFrame(streamId, cancelFrame); if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); } return; } @@ -259,7 +259,7 @@ public final void cancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); } } else if (!hasRequested(previousState)) { // no need to send anything, since the first request has not happened @@ -290,11 +290,12 @@ public final void handleComplete() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, null); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); } this.inboundSubscriber.onComplete(); @@ -315,13 +316,14 @@ public final void handleError(Throwable cause) { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); ReassemblyUtils.synchronizedRelease(this, previousState); final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, cause); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); } this.inboundSubscriber.onError(cause); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java index 6b06bc119..774fae9e5 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -146,7 +146,7 @@ public void onNext(Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, e); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); } return; } @@ -164,7 +164,7 @@ public void onNext(Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, e); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); } return; } @@ -178,7 +178,7 @@ public void onNext(Payload p) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); } } } @@ -229,7 +229,7 @@ public void onError(Throwable t) { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); } } @@ -253,7 +253,7 @@ public void onComplete() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, null); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); } } @@ -285,7 +285,7 @@ public final void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); } return; } @@ -301,7 +301,7 @@ public final void handleCancel() { final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); } } @@ -336,7 +336,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, e); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); } logger.debug("Reassembly has failed", e); @@ -368,7 +368,7 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas final RequestInterceptor requestInterceptor = this.requestInterceptor; if (requestInterceptor != null) { - requestInterceptor.onTerminate(streamId, t); + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); } logger.debug("Reassembly has failed", t); diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java index b4e1a1ba3..d455c79ba 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java @@ -40,12 +40,12 @@ public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metad } @Override - public void onTerminate(int streamId, @Nullable Throwable cause) { + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { final RequestInterceptor[] requestInterceptors = this.requestInterceptors; for (int i = 0; i < requestInterceptors.length; i++) { final RequestInterceptor requestInterceptor = requestInterceptors[i]; try { - requestInterceptor.onTerminate(streamId, cause); + requestInterceptor.onTerminate(streamId, requestType, cause); } catch (Throwable t) { Operators.onErrorDropped(t, Context.empty()); } @@ -53,12 +53,12 @@ public void onTerminate(int streamId, @Nullable Throwable cause) { } @Override - public void onCancel(int streamId) { + public void onCancel(int streamId, FrameType requestType) { final RequestInterceptor[] requestInterceptors = this.requestInterceptors; for (int i = 0; i < requestInterceptors.length; i++) { final RequestInterceptor requestInterceptor = requestInterceptors[i]; try { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, requestType); } catch (Throwable t) { Operators.onErrorDropped(t, Context.empty()); } @@ -121,18 +121,18 @@ public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metad } @Override - public void onTerminate(int streamId, @Nullable Throwable cause) { + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { try { - requestInterceptor.onTerminate(streamId, cause); + requestInterceptor.onTerminate(streamId, requestType, cause); } catch (Throwable t) { Operators.onErrorDropped(t, Context.empty()); } } @Override - public void onCancel(int streamId) { + public void onCancel(int streamId, FrameType requestType) { try { - requestInterceptor.onCancel(streamId); + requestInterceptor.onCancel(streamId, requestType); } catch (Throwable t) { Operators.onErrorDropped(t, Context.empty()); } diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java index 5da850837..08131b39d 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -33,21 +33,27 @@ public interface RequestInterceptor extends Disposable { /** * Method which is being invoked once a successfully accepted request is terminated. This method * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is - * exclusive with {@link #onCancel(int)}. + * exclusive with {@link #onCancel(int, FrameType)}. * * @param streamId used by this request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} * @param t with which this finished has terminated. Must be one of the following signals */ - void onTerminate(int streamId, @Nullable Throwable t); + void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t); /** * Method which is being invoked once a successfully accepted request is cancelled. This method * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is - * exclusive with {@link #onTerminate(int, Throwable)}. + * exclusive with {@link #onTerminate(int, FrameType, Throwable)}. * + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} * @param streamId used by this request */ - void onCancel(int streamId); + void onCancel(int streamId, FrameType requestType); /** * Method which is being invoked on the request rejection. This method is being called only if the diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java index 24a035b78..6f156a380 100644 --- a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java @@ -520,12 +520,13 @@ public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metad } @Override - public void onTerminate(int streamId, @Nullable Throwable terminalSignal) { + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { throw new RuntimeException("testOnTerminate"); } @Override - public void onCancel(int streamId) { + public void onCancel(int streamId, FrameType requestType) { throw new RuntimeException("testOnCancel"); } @@ -620,12 +621,13 @@ public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metad } @Override - public void onTerminate(int streamId, @Nullable Throwable terminalSignal) { + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { throw new RuntimeException("testOnTerminate"); } @Override - public void onCancel(int streamId) { + public void onCancel(int streamId, FrameType requestType) { throw new RuntimeException("testOnTerminate"); } diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java index fe9de7ce1..8261b3f49 100644 --- a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -22,14 +22,15 @@ public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metad } @Override - public void onTerminate(int streamId, @Nullable Throwable t) { + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { events.add( - new Event(t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, null, t)); + new Event( + t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, requestType, t)); } @Override - public void onCancel(int streamId) { - events.add(new Event(EventType.ON_CANCEL, streamId, null, null)); + public void onCancel(int streamId, FrameType requestType) { + events.add(new Event(EventType.ON_CANCEL, streamId, requestType, null)); } @Override From 626a9f5b6ee5153b32c6ba4c8ca23b8d561ce759 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 22 Oct 2020 18:09:17 +0300 Subject: [PATCH 2/2] migrates to RequestInterceptor to track Stats Signed-off-by: Oleh Dokuka --- benchmarks/README.md | 2 +- .../loadbalance/BaseWeightedStats.java | 221 ++++ .../ClientLoadbalanceStrategy.java | 14 + .../java/io/rsocket/loadbalance/Ewma.java | 24 +- .../loadbalance/FluxDeferredResolution.java | 8 +- .../rsocket/loadbalance/FrugalQuantile.java | 10 +- .../rsocket/loadbalance/Int2LongHashMap.java | 1005 +++++++++++++++++ .../loadbalance/LoadbalanceRSocketClient.java | 9 +- .../java/io/rsocket/loadbalance/Median.java | 1 + ...eightedRSocket.java => PooledRSocket.java} | 140 +-- .../io/rsocket/loadbalance/RSocketPool.java | 56 +- .../java/io/rsocket/loadbalance/Stats.java | 308 ----- .../WeightedLoadbalanceStrategy.java | 120 +- .../rsocket/loadbalance/WeightedRSocket.java | 23 - .../io/rsocket/loadbalance/WeightedStats.java | 19 + .../WeightedStatsRequestInterceptor.java | 91 ++ .../RoundRobinRSocketLoadbalancerExample.java | 3 +- 17 files changed, 1521 insertions(+), 533 deletions(-) create mode 100644 rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java create mode 100644 rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java create mode 100644 rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java rename rsocket-core/src/main/java/io/rsocket/loadbalance/{PooledWeightedRSocket.java => PooledRSocket.java} (59%) delete mode 100644 rsocket-core/src/main/java/io/rsocket/loadbalance/Stats.java delete mode 100644 rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedRSocket.java create mode 100644 rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java create mode 100644 rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java diff --git a/benchmarks/README.md b/benchmarks/README.md index 6ba6755a6..656e2de4b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -17,7 +17,7 @@ Specify extra profilers: Prominent profilers (for full list call `jmhProfilers` task): - comp - JitCompilations, tune your iterations - stack - which methods used most time -- gc - print garbage collection stats +- gc - print garbage collection defaultWeightedStats - hs_thr - thread usage Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java new file mode 100644 index 000000000..6514244c3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java @@ -0,0 +1,221 @@ +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * The base implementation of the {@link WeightedStats} interface + * + * @since 1.1 + */ +public class BaseWeightedStats implements WeightedStats { + + private static final double DEFAULT_LOWER_QUANTILE = 0.5; + private static final double DEFAULT_HIGHER_QUANTILE = 0.8; + private static final int INACTIVITY_FACTOR = 500; + private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = + Clock.unit().convert(1L, TimeUnit.SECONDS); + + private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; + + private final Quantile lowerQuantile; + private final Quantile higherQuantile; + private final Ewma availabilityPercentage; + private final Median median; + private final Ewma interArrivalTime; + + private final long tau; + private final long inactivityFactor; + + private long errorStamp; // last we got an error + private long stamp; // last timestamp we sent a request + private long stamp0; // last timestamp we sent a request or receive a response + private long duration; // instantaneous cumulative duration + + private double availability = 1.0; + + private volatile int pendingRequests; // instantaneous rate + private static final AtomicIntegerFieldUpdater PENDING_REQUESTS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingRequests"); + private volatile int pendingStreams; // number of active streams + private static final AtomicIntegerFieldUpdater PENDING_STREAMS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingStreams"); + + protected BaseWeightedStats() { + this( + new FrugalQuantile(DEFAULT_LOWER_QUANTILE), + new FrugalQuantile(DEFAULT_HIGHER_QUANTILE), + INACTIVITY_FACTOR); + } + + private BaseWeightedStats( + Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { + this.lowerQuantile = lowerQuantile; + this.higherQuantile = higherQuantile; + this.inactivityFactor = inactivityFactor; + + long now = Clock.now(); + this.stamp = now; + this.errorStamp = now; + this.stamp0 = now; + this.duration = 0L; + this.pendingRequests = 0; + this.median = new Median(); + this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); + this.availabilityPercentage = new Ewma(5, TimeUnit.SECONDS, 1.0); + this.tau = Clock.unit().convert((long) (5 / Math.log(2)), TimeUnit.SECONDS); + } + + @Override + public double lowerQuantileLatency() { + return lowerQuantile.estimation(); + } + + @Override + public double higherQuantileLatency() { + return higherQuantile.estimation(); + } + + @Override + public int pending() { + return pendingRequests + pendingStreams; + } + + @Override + public double availability() { + if (Clock.now() - stamp > tau) { + updateAvailability(1.0); + } + return availability * availabilityPercentage.value(); + } + + @Override + public double predictedLatency() { + final long now = Clock.now(); + final long elapsed; + + synchronized (this) { + elapsed = Math.max(now - stamp, 1L); + } + + final double latency; + final double prediction = median.estimation(); + + final int pending = this.pending(); + if (prediction == 0.0) { + if (pending == 0) { + latency = 0.0; // first request + } else { + // subsequent requests while we don't have any history + latency = STARTUP_PENALTY + pending; + } + } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { + // if we did't see any data for a while, we decay the prediction by inserting + // artificial 0.0 into the median + median.insert(0.0); + latency = median.estimation(); + } else { + final double predicted = prediction * pending; + final double instant = instantaneous(now, pending); + + if (predicted < instant) { // NB: (0.0 < 0.0) == false + latency = instant / pending; // NB: pending never equal 0 here + } else { + // we are under the predictions + latency = prediction; + } + } + + return latency; + } + + long instantaneous(long now, int pending) { + return duration + (now - stamp0) * pending; + } + + void startStream() { + PENDING_STREAMS.incrementAndGet(this); + } + + void stopStream() { + PENDING_STREAMS.decrementAndGet(this); + } + + synchronized long startRequest() { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + interArrivalTime.insert(now - stamp); + duration += Math.max(0, now - stamp0) * pendingRequests; + PENDING_REQUESTS.lazySet(this, pendingRequests + 1); + stamp = now; + stamp0 = now; + + return now; + } + + synchronized long stopRequest(long timestamp) { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + duration += Math.max(0, now - stamp0) * pendingRequests - (now - timestamp); + PENDING_REQUESTS.lazySet(this, pendingRequests - 1); + stamp0 = now; + + return now; + } + + synchronized void record(double roundTripTime) { + median.insert(roundTripTime); + lowerQuantile.insert(roundTripTime); + higherQuantile.insert(roundTripTime); + } + + void updateAvailability(double value) { + availabilityPercentage.insert(value); + if (value == 0.0d) { + synchronized (this) { + errorStamp = Clock.now(); + } + } + } + + void setAvailability(double availability) { + this.availability = availability; + } + + @Override + public String toString() { + return "Stats{" + + "lowerQuantile=" + + lowerQuantile.estimation() + + ", higherQuantile=" + + higherQuantile.estimation() + + ", inactivityFactor=" + + inactivityFactor + + ", tau=" + + tau + + ", errorPercentage=" + + availabilityPercentage.value() + + ", pending=" + + pendingRequests + + ", errorStamp=" + + errorStamp + + ", stamp=" + + stamp + + ", stamp0=" + + stamp0 + + ", duration=" + + duration + + ", median=" + + median.estimation() + + ", interArrivalTime=" + + interArrivalTime.value() + + ", pendingStreams=" + + pendingStreams + + ", availability=" + + availability + + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java new file mode 100644 index 000000000..a35151fa6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java @@ -0,0 +1,14 @@ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; + +/** + * Extension for {@link LoadbalanceStrategy} which allows pre-setup {@link RSocketConnector} for + * {@link LoadbalanceStrategy} needs + * + * @since 1.1 + */ +public interface ClientLoadbalanceStrategy extends LoadbalanceStrategy { + + void initialize(RSocketConnector connector); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java index 4812114dd..0f87f6510 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java @@ -18,6 +18,7 @@ import io.rsocket.util.Clock; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; /** * Compute the exponential weighted moving average of a series of values. The time at which you @@ -28,20 +29,27 @@ * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) */ class Ewma { - private final long tau; - private volatile long stamp; - private volatile double ewma; + + final long tau; + + volatile long stamp; + static final AtomicLongFieldUpdater STAMP = + AtomicLongFieldUpdater.newUpdater(Ewma.class, "stamp"); + volatile double ewma; public Ewma(long halfLife, TimeUnit unit, double initialValue) { this.tau = Clock.unit().convert((long) (halfLife / Math.log(2)), unit); - stamp = 0L; - ewma = initialValue; + + this.ewma = initialValue; + + STAMP.lazySet(this, 0L); } public synchronized void insert(double x) { - long now = Clock.now(); - double elapsed = Math.max(0, now - stamp); - stamp = now; + final long now = Clock.now(); + final double elapsed = Math.max(0, now - stamp); + + STAMP.lazySet(this, now); double w = Math.exp(-elapsed / tau); ewma = w * ewma + (1.0 - w) * x; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java index 337edc530..6c2b9c3ea 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java @@ -83,7 +83,7 @@ public final Context currentContext() { @Nullable @Override - public Object scanUnsafe(Attr key) { + public final Object scanUnsafe(Attr key) { long state = this.requested; if (key == Attr.PARENT) { @@ -145,7 +145,7 @@ public final void onNext(Payload payload) { } @Override - public void onError(Throwable t) { + public final void onError(Throwable t) { if (this.done) { Operators.onErrorDropped(t, this.actual.currentContext()); return; @@ -156,7 +156,7 @@ public void onError(Throwable t) { } @Override - public void onComplete() { + public final void onComplete() { if (this.done) { return; } @@ -206,7 +206,7 @@ public final void request(long n) { } } - public void cancel() { + public final void cancel() { long state = REQUESTED.getAndSet(this, STATE_TERMINATED); if (state == STATE_TERMINATED) { return; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java index efa32ff83..a15d88529 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java @@ -26,12 +26,14 @@ *

More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ */ class FrugalQuantile implements Quantile { - private final double increment; - volatile double estimate; + final double increment; + final SplittableRandom rnd; + int step; int sign; - private double quantile; - private SplittableRandom rnd; + double quantile; + + volatile double estimate; public FrugalQuantile(double quantile, double increment) { this.increment = increment; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java new file mode 100644 index 000000000..eebf82fe9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java @@ -0,0 +1,1005 @@ +/* + * Copyright 2014-2020 Real Logic Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.io.Serializable; +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.IntToLongFunction; +import reactor.util.annotation.Nullable; + +/** A open addressing with linear probing hash map specialised for primitive key and value pairs. */ +class Int2LongHashMap implements Map, Serializable { + static final float DEFAULT_LOAD_FACTOR = 0.55f; + static final int MIN_CAPACITY = 8; + private static final long serialVersionUID = -690554872053575793L; + + private final float loadFactor; + private final long missingValue; + private int resizeThreshold; + private int size = 0; + private final boolean shouldAvoidAllocation; + + private long[] entries; + private KeySet keySet; + private ValueCollection values; + private EntrySet entrySet; + + /** @param missingValue for the map that represents null. */ + public Int2LongHashMap(final long missingValue) { + this(MIN_CAPACITY, DEFAULT_LOAD_FACTOR, missingValue); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + */ + public Int2LongHashMap( + final int initialCapacity, final float loadFactor, final long missingValue) { + this(initialCapacity, loadFactor, missingValue, true); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + * @param shouldAvoidAllocation should allocation be avoided by caching iterators and map entries. + */ + public Int2LongHashMap( + final int initialCapacity, + final float loadFactor, + final long missingValue, + final boolean shouldAvoidAllocation) { + validateLoadFactor(loadFactor); + + this.loadFactor = loadFactor; + this.missingValue = missingValue; + this.shouldAvoidAllocation = shouldAvoidAllocation; + + capacity(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, initialCapacity))); + } + + /** + * The value to be used as a null marker in the map. + * + * @return value to be used as a null marker in the map. + */ + public long missingValue() { + return missingValue; + } + + /** + * Get the load factor applied for resize operations. + * + * @return the load factor applied for resize operations. + */ + public float loadFactor() { + return loadFactor; + } + + /** + * Get the total capacity for the map to which the load factor will be a fraction of. + * + * @return the total capacity for the map. + */ + public int capacity() { + return entries.length >> 1; + } + + /** + * Get the actual threshold which when reached the map will resize. This is a function of the + * current capacity and load factor. + * + * @return the threshold when the map will resize. + */ + public int resizeThreshold() { + return resizeThreshold; + } + + /** {@inheritDoc} */ + public int size() { + return size; + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return size == 0; + } + + /** + * Get a value using provided key avoiding boxing. + * + * @param key lookup key. + * @return value associated with the key or {@link #missingValue()} if key is not found in the + * map. + */ + public long get(final int key) { + final int mask = entries.length - 1; + int index = evenHash(key, mask); + + long value = missingValue; + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + value = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + return value; + } + + /** + * Put a key value pair in the map. + * + * @param key lookup key + * @param value new value, must not be {@link #missingValue()} + * @return previous value associated with the key, or {@link #missingValue()} if none found + * @throws IllegalArgumentException if value is {@link #missingValue()} + */ + public long put(final int key, final long value) { + if (value == missingValue) { + throw new IllegalArgumentException("cannot accept missingValue"); + } + + final int mask = entries.length - 1; + int index = evenHash(key, mask); + long oldValue = missingValue; + + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + oldValue = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + if (oldValue == missingValue) { + ++size; + entries[index] = key; + } + + entries[index + 1] = value; + + increaseCapacity(); + + return oldValue; + } + + private void increaseCapacity() { + if (size > resizeThreshold) { + // entries.length = 2 * capacity + final int newCapacity = entries.length; + rehash(newCapacity); + } + } + + private void rehash(final int newCapacity) { + final long[] oldEntries = entries; + final int length = entries.length; + + capacity(newCapacity); + + final long[] newEntries = entries; + final int mask = entries.length - 1; + + for (int keyIndex = 0; keyIndex < length; keyIndex += 2) { + final long value = oldEntries[keyIndex + 1]; + if (value != missingValue) { + final int key = (int) oldEntries[keyIndex]; + int index = evenHash(key, mask); + + while (newEntries[index + 1] != missingValue) { + index = next(index, mask); + } + + newEntries[index] = key; + newEntries[index + 1] = value; + } + } + } + + /** + * Int primitive specialised containsKey. + * + * @param key the key to check. + * @return true if the map contains key as a key, false otherwise. + */ + public boolean containsKey(final int key) { + return get(key) != missingValue; + } + + /** + * Does the map contain the value. + * + * @param value to be tested against contained values. + * @return true if contained otherwise value. + */ + public boolean containsValue(final long value) { + boolean found = false; + if (value != missingValue) { + final int length = entries.length; + int remaining = size; + + for (int valueIndex = 1; remaining > 0 && valueIndex < length; valueIndex += 2) { + if (missingValue != entries[valueIndex]) { + if (value == entries[valueIndex]) { + found = true; + break; + } + --remaining; + } + } + } + + return found; + } + + /** {@inheritDoc} */ + public void clear() { + if (size > 0) { + Arrays.fill(entries, missingValue); + size = 0; + } + } + + /** + * Compact the backing arrays by rehashing with a capacity just larger than current size and + * giving consideration to the load factor. + */ + public void compact() { + final int idealCapacity = (int) Math.round(size() * (1.0d / loadFactor)); + rehash(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, idealCapacity))); + } + + /** + * Primitive specialised version of {@link #computeIfAbsent(Object, Function)} + * + * @param key to search on. + * @param mappingFunction to provide a value if the get returns null. + * @return the value if found otherwise the missing value. + */ + public long computeIfAbsent(final int key, final IntToLongFunction mappingFunction) { + long value = get(key); + if (value == missingValue) { + value = mappingFunction.applyAsLong(key); + if (value != missingValue) { + put(key, value); + } + } + + return value; + } + + // ---------------- Boxed Versions Below ---------------- + + /** {@inheritDoc} */ + @Nullable + public Long get(final Object key) { + return valOrNull(get((int) key)); + } + + /** {@inheritDoc} */ + public Long put(final Integer key, final Long value) { + return valOrNull(put((int) key, (long) value)); + } + + /** {@inheritDoc} */ + public boolean containsKey(final Object key) { + return containsKey((int) key); + } + + /** {@inheritDoc} */ + public boolean containsValue(final Object value) { + return containsValue((long) value); + } + + /** {@inheritDoc} */ + public void putAll(final Map map) { + for (final Map.Entry entry : map.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + /** {@inheritDoc} */ + public KeySet keySet() { + if (null == keySet) { + keySet = new KeySet(); + } + + return keySet; + } + + /** {@inheritDoc} */ + public ValueCollection values() { + if (null == values) { + values = new ValueCollection(); + } + + return values; + } + + /** {@inheritDoc} */ + public EntrySet entrySet() { + if (null == entrySet) { + entrySet = new EntrySet(); + } + + return entrySet; + } + + /** {@inheritDoc} */ + @Nullable + public Long remove(final Object key) { + return valOrNull(remove((int) key)); + } + + /** + * Remove value from the map using given key avoiding boxing. + * + * @param key whose mapping is to be removed from the map. + * @return removed value or {@link #missingValue()} if key was not found in the map. + */ + public long remove(final int key) { + final int mask = entries.length - 1; + int keyIndex = evenHash(key, mask); + + long oldValue = missingValue; + while (entries[keyIndex + 1] != missingValue) { + if (entries[keyIndex] == key) { + oldValue = entries[keyIndex + 1]; + entries[keyIndex + 1] = missingValue; + size--; + + compactChain(keyIndex); + + break; + } + + keyIndex = next(keyIndex, mask); + } + + return oldValue; + } + + @SuppressWarnings("FinalParameters") + private void compactChain(int deleteKeyIndex) { + final int mask = entries.length - 1; + int keyIndex = deleteKeyIndex; + + while (true) { + keyIndex = next(keyIndex, mask); + if (entries[keyIndex + 1] == missingValue) { + break; + } + + final int hash = evenHash((int) entries[keyIndex], mask); + + if ((keyIndex < hash && (hash <= deleteKeyIndex || deleteKeyIndex <= keyIndex)) + || (hash <= deleteKeyIndex && deleteKeyIndex <= keyIndex)) { + entries[deleteKeyIndex] = entries[keyIndex]; + entries[deleteKeyIndex + 1] = entries[keyIndex + 1]; + + entries[keyIndex + 1] = missingValue; + deleteKeyIndex = keyIndex; + } + } + } + + /** + * Get the minimum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the minimum value stored in the map. + */ + public long minValue() { + final long missingValue = this.missingValue; + long min = size == 0 ? missingValue : Long.MAX_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + min = Math.min(min, value); + } + } + + return min; + } + + /** + * Get the maximum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the maximum value stored in the map. + */ + public long maxValue() { + final long missingValue = this.missingValue; + long max = size == 0 ? missingValue : Long.MIN_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + max = Math.max(max, value); + } + } + + return max; + } + + /** {@inheritDoc} */ + public String toString() { + if (isEmpty()) { + return "{}"; + } + + final EntryIterator entryIterator = new EntryIterator(); + entryIterator.reset(); + + final StringBuilder sb = new StringBuilder().append('{'); + while (true) { + entryIterator.next(); + sb.append(entryIterator.getIntKey()).append('=').append(entryIterator.getLongValue()); + if (!entryIterator.hasNext()) { + return sb.append('}').toString(); + } + sb.append(',').append(' '); + } + } + + /** + * Primitive specialised version of {@link #replace(Object, Object)} + * + * @param key key with which the specified value is associated + * @param value value to be associated with the specified key + * @return the previous value associated with the specified key, or {@link #missingValue()} if + * there was no mapping for the key. + */ + public long replace(final int key, final long value) { + long currentValue = get(key); + if (currentValue != missingValue) { + currentValue = put(key, value); + } + + return currentValue; + } + + /** + * Primitive specialised version of {@link #replace(Object, Object, Object)} + * + * @param key key with which the specified value is associated + * @param oldValue value expected to be associated with the specified key + * @param newValue value to be associated with the specified key + * @return {@code true} if the value was replaced + */ + public boolean replace(final int key, final long oldValue, final long newValue) { + final long curValue = get(key); + if (curValue != oldValue || curValue == missingValue) { + return false; + } + + put(key, newValue); + + return true; + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Map)) { + return false; + } + + final Map that = (Map) o; + + return size == that.size() && entrySet().equals(that.entrySet()); + } + + public int hashCode() { + return entrySet().hashCode(); + } + + private static int next(final int index, final int mask) { + return (index + 2) & mask; + } + + private void capacity(final int newCapacity) { + final int entriesLength = newCapacity * 2; + if (entriesLength < 0) { + throw new IllegalStateException("max capacity reached at size=" + size); + } + + /*@DoNotSub*/ resizeThreshold = (int) (newCapacity * loadFactor); + entries = new long[entriesLength]; + Arrays.fill(entries, missingValue); + } + + @Nullable + private Long valOrNull(final long value) { + return value == missingValue ? null : value; + } + + // ---------------- Utility Classes ---------------- + + /** Base iterator implementation. */ + abstract class AbstractIterator implements Serializable { + private static final long serialVersionUID = 5262459454112462433L; + /** Is current position valid. */ + protected boolean isPositionValid = false; + + private int remaining; + private int positionCounter; + private int stopCounter; + + final void reset() { + isPositionValid = false; + remaining = Int2LongHashMap.this.size; + final long missingValue = Int2LongHashMap.this.missingValue; + final long[] entries = Int2LongHashMap.this.entries; + final int capacity = entries.length; + + int keyIndex = capacity; + if (entries[capacity - 1] != missingValue) { + for (int i = 1; i < capacity; i += 2) { + if (entries[i] == missingValue) { + keyIndex = i - 1; + break; + } + } + } + + stopCounter = keyIndex; + positionCounter = keyIndex + capacity; + } + + /** + * Returns position of the key of the current entry. + * + * @return key position. + */ + protected final int keyPosition() { + return positionCounter & entries.length - 1; + } + + /** + * Number of remaining elements. + * + * @return number of remaining elements. + */ + public int remaining() { + return remaining; + } + + /** + * Check if there are more elements remaining. + * + * @return {@code true} if {@code remaining > 0}. + */ + public boolean hasNext() { + return remaining > 0; + } + + /** + * Advance to the next entry. + * + * @throws NoSuchElementException if no more entries available. + */ + protected final void findNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final long[] entries = Int2LongHashMap.this.entries; + final long missingValue = Int2LongHashMap.this.missingValue; + final int mask = entries.length - 1; + + for (int keyIndex = positionCounter - 2; keyIndex >= stopCounter; keyIndex -= 2) { + final int index = keyIndex & mask; + if (entries[index + 1] != missingValue) { + isPositionValid = true; + positionCounter = keyIndex; + --remaining; + return; + } + } + + isPositionValid = false; + throw new IllegalStateException(); + } + + /** {@inheritDoc} */ + public void remove() { + if (isPositionValid) { + final int position = keyPosition(); + entries[position + 1] = missingValue; + --size; + + compactChain(position); + + isPositionValid = false; + } else { + throw new IllegalStateException(); + } + } + } + + /** Iterator over keys which supports access to unboxed keys via {@link #nextValue()}. */ + public final class KeyIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = 9151493609653852972L; + + public Integer next() { + return nextValue(); + } + + /** + * Return next key. + * + * @return next key. + */ + public int nextValue() { + findNext(); + return (int) entries[keyPosition()]; + } + } + + /** Iterator over values which supports access to unboxed values. */ + public final class ValueIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = -5670291734793552927L; + + public Long next() { + return nextValue(); + } + + /** + * Return next value. + * + * @return next value. + */ + public long nextValue() { + findNext(); + return entries[keyPosition() + 1]; + } + } + + /** Iterator over entries which supports access to unboxed keys and values. */ + public final class EntryIterator extends AbstractIterator + implements Iterator>, Entry, Serializable { + private static final long serialVersionUID = 1744408438593481051L; + + public Integer getKey() { + return getIntKey(); + } + + /** + * Returns the key of the current entry. + * + * @return the key. + */ + public int getIntKey() { + return (int) entries[keyPosition()]; + } + + public Long getValue() { + return getLongValue(); + } + + /** + * Returns the value of the current entry. + * + * @return the value. + */ + public long getLongValue() { + return entries[keyPosition() + 1]; + } + + public Long setValue(final Long value) { + return setValue(value.longValue()); + } + + /** + * Sets the value of the current entry. + * + * @param value to be set. + * @return previous value of the entry. + */ + public long setValue(final long value) { + if (!isPositionValid) { + throw new IllegalStateException(); + } + + if (missingValue == value) { + throw new IllegalArgumentException(); + } + + final int keyPosition = keyPosition(); + final long prevValue = entries[keyPosition + 1]; + entries[keyPosition + 1] = value; + return prevValue; + } + + public Entry next() { + findNext(); + + if (shouldAvoidAllocation) { + return this; + } + + return allocateDuplicateEntry(); + } + + private Entry allocateDuplicateEntry() { + return new MapEntry(getIntKey(), getLongValue()); + } + + /** {@inheritDoc} */ + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Entry)) { + return false; + } + + final Entry that = (Entry) o; + + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + /** An {@link java.util.Map.Entry} implementation. */ + public final class MapEntry implements Entry { + private final int k; + private final long v; + + /** + * Constructs entry with given key and value. + * + * @param k key. + * @param v value. + */ + public MapEntry(final int k, final long v) { + this.k = k; + this.v = v; + } + + public Integer getKey() { + return k; + } + + public Long getValue() { + return v; + } + + public Long setValue(final Long value) { + return Int2LongHashMap.this.put(k, value.longValue()); + } + + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + public boolean equals(final Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + + final Entry e = (Entry) o; + + return (e.getKey() != null && e.getValue() != null) + && (e.getKey().equals(k) && e.getValue().equals(v)); + } + + public String toString() { + return k + "=" + v; + } + } + } + + /** Set of keys which supports optional cached iterators to avoid allocation. */ + public final class KeySet extends AbstractSet implements Serializable { + private static final long serialVersionUID = -7645453993079742625L; + private final KeyIterator keyIterator = shouldAvoidAllocation ? new KeyIterator() : null; + + /** {@inheritDoc} */ + public KeyIterator iterator() { + KeyIterator keyIterator = this.keyIterator; + if (null == keyIterator) { + keyIterator = new KeyIterator(); + } + + keyIterator.reset(); + + return keyIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((int) o); + } + + /** + * Checks if key is contained in the map without boxing. + * + * @param key to check. + * @return {@code true} if key is contained in this map. + */ + public boolean contains(final int key) { + return containsKey(key); + } + } + + /** Collection of values which supports optionally cached iterators to avoid allocation. */ + public final class ValueCollection extends AbstractCollection implements Serializable { + private static final long serialVersionUID = -8925598924781601919L; + private final ValueIterator valueIterator = shouldAvoidAllocation ? new ValueIterator() : null; + + /** {@inheritDoc} */ + public ValueIterator iterator() { + ValueIterator valueIterator = this.valueIterator; + if (null == valueIterator) { + valueIterator = new ValueIterator(); + } + + valueIterator.reset(); + + return valueIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((long) o); + } + + /** + * Checks if the value is contained in the map. + * + * @param value to be checked. + * @return {@code true} if value is contained in this map. + */ + public boolean contains(final long value) { + return containsValue(value); + } + } + + /** Set of entries which supports optionally cached iterators to avoid allocation. */ + public final class EntrySet extends AbstractSet> + implements Serializable { + private static final long serialVersionUID = 63641283589916174L; + private final EntryIterator entryIterator = shouldAvoidAllocation ? new EntryIterator() : null; + + /** {@inheritDoc} */ + public EntryIterator iterator() { + EntryIterator entryIterator = this.entryIterator; + if (null == entryIterator) { + entryIterator = new EntryIterator(); + } + + entryIterator.reset(); + + return entryIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + if (!(o instanceof Entry)) { + return false; + } + final Entry entry = (Entry) o; + final Long value = get(entry.getKey()); + + return value != null && value.equals(entry.getValue()); + } + + /** {@inheritDoc} */ + public Object[] toArray() { + return toArray(new Object[size()]); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + public T[] toArray(final T[] a) { + final T[] array = + a.length >= size + ? a + : (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size); + final EntryIterator it = iterator(); + + for (int i = 0; i < array.length; i++) { + if (it.hasNext()) { + it.next(); + array[i] = (T) it.allocateDuplicateEntry(); + } else { + array[i] = null; + break; + } + } + + return array; + } + } + + private static int evenHash(final int value, final int mask) { + final int hash = (value << 1) - (value << 8); + + return hash & mask; + } + + private static void validateLoadFactor(final float loadFactor) { + if (loadFactor < 0.1f || loadFactor > 0.9f) { + throw new IllegalArgumentException( + "load factor must be in the range of 0.1 to 0.9: " + loadFactor); + } + } + + private static int findNextPositivePowerOfTwo(final int value) { + return 1 << (Integer.SIZE - Integer.numberOfLeadingZeros(value - 1)); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java index 89ae01f18..8822632a0 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java @@ -165,8 +165,15 @@ public Builder loadbalanceStrategy(LoadbalanceStrategy strategy) { /** Build the {@link LoadbalanceRSocketClient} instance. */ public LoadbalanceRSocketClient build() { + final RSocketConnector connector = initConnector(); + final LoadbalanceStrategy strategy = initLoadbalanceStrategy(); + + if (strategy instanceof ClientLoadbalanceStrategy) { + ((ClientLoadbalanceStrategy) strategy).initialize(connector); + } + return new LoadbalanceRSocketClient( - new RSocketPool(initConnector(), this.targetPublisher, initLoadbalanceStrategy())); + new RSocketPool(connector, this.targetPublisher, strategy)); } private RSocketConnector initConnector() { diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java index 833bd5380..42b125b41 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java @@ -18,6 +18,7 @@ /** This implementation gives better results because it considers more data-point. */ class Median extends FrugalQuantile { + public Median() { super(0.5, 1.0); } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledWeightedRSocket.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java similarity index 59% rename from rsocket-core/src/main/java/io/rsocket/loadbalance/PooledWeightedRSocket.java rename to rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java index ad681087e..3d9011bf6 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledWeightedRSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java @@ -28,29 +28,24 @@ import reactor.core.publisher.Operators; import reactor.util.context.Context; -/** Default implementation of {@link WeightedRSocket} stored in {@link RSocketPool} */ -final class PooledWeightedRSocket extends ResolvingOperator - implements CoreSubscriber, WeightedRSocket { +/** Default implementation of {@link RSocket} stored in {@link RSocketPool} */ +final class PooledRSocket extends ResolvingOperator + implements CoreSubscriber, RSocket { final RSocketPool parent; final Mono rSocketSource; final LoadbalanceTarget loadbalanceTarget; - final Stats stats; volatile Subscription s; - static final AtomicReferenceFieldUpdater S = - AtomicReferenceFieldUpdater.newUpdater(PooledWeightedRSocket.class, Subscription.class, "s"); + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(PooledRSocket.class, Subscription.class, "s"); - PooledWeightedRSocket( - RSocketPool parent, - Mono rSocketSource, - LoadbalanceTarget loadbalanceTarget, - Stats stats) { + PooledRSocket( + RSocketPool parent, Mono rSocketSource, LoadbalanceTarget loadbalanceTarget) { this.parent = parent; this.rSocketSource = rSocketSource; this.loadbalanceTarget = loadbalanceTarget; - this.stats = stats; } @Override @@ -113,13 +108,11 @@ protected void doSubscribe() { @Override protected void doOnValueResolved(RSocket value) { - stats.setAvailability(1.0); value.onClose().subscribe(null, t -> this.invalidate(), this::invalidate); } @Override protected void doOnValueExpired(RSocket value) { - stats.setAvailability(0.0); value.dispose(); this.dispose(); } @@ -133,7 +126,7 @@ public void dispose() { protected void doOnDispose() { final RSocketPool parent = this.parent; for (; ; ) { - final PooledWeightedRSocket[] sockets = parent.activeSockets; + final PooledRSocket[] sockets = parent.activeSockets; final int activeSocketsCount = sockets.length; int index = -1; @@ -149,7 +142,7 @@ protected void doOnDispose() { } final int lastIndex = activeSocketsCount - 1; - final PooledWeightedRSocket[] newSockets = new PooledWeightedRSocket[lastIndex]; + final PooledRSocket[] newSockets = new PooledRSocket[lastIndex]; if (index != 0) { System.arraycopy(sockets, 0, newSockets, 0, index); } @@ -162,43 +155,32 @@ protected void doOnDispose() { break; } } - stats.setAvailability(0.0); Operators.terminate(S, this); } @Override public Mono fireAndForget(Payload payload) { - return new RequestTrackingMonoInner<>(this, payload, FrameType.REQUEST_FNF); + return new MonoInner<>(this, payload, FrameType.REQUEST_FNF); } @Override public Mono requestResponse(Payload payload) { - return new RequestTrackingMonoInner<>(this, payload, FrameType.REQUEST_RESPONSE); + return new MonoInner<>(this, payload, FrameType.REQUEST_RESPONSE); } @Override public Flux requestStream(Payload payload) { - return new RequestTrackingFluxInner<>(this, payload, FrameType.REQUEST_STREAM); + return new FluxInner<>(this, payload, FrameType.REQUEST_STREAM); } @Override public Flux requestChannel(Publisher payloads) { - return new RequestTrackingFluxInner<>(this, payloads, FrameType.REQUEST_CHANNEL); + return new FluxInner<>(this, payloads, FrameType.REQUEST_CHANNEL); } @Override public Mono metadataPush(Payload payload) { - return new RequestTrackingMonoInner<>(this, payload, FrameType.METADATA_PUSH); - } - - /** - * Indicates number of active requests - * - * @return number of requests in progress - */ - @Override - public Stats stats() { - return stats; + return new MonoInner<>(this, payload, FrameType.METADATA_PUSH); } LoadbalanceTarget target() { @@ -207,15 +189,13 @@ LoadbalanceTarget target() { @Override public double availability() { - return stats.availability(); + final RSocket socket = valueIfResolved(); + return socket != null ? socket.availability() : 0.0d; } - static final class RequestTrackingMonoInner - extends MonoDeferredResolution { + static final class MonoInner extends MonoDeferredResolution { - long startTime; - - RequestTrackingMonoInner(PooledWeightedRSocket parent, Payload payload, FrameType requestType) { + MonoInner(PooledRSocket parent, Payload payload, FrameType requestType) { super(parent, payload, requestType); } @@ -249,58 +229,16 @@ public void accept(RSocket rSocket, Throwable t) { return; } - startTime = ((PooledWeightedRSocket) parent).stats.startRequest(); - source.subscribe((CoreSubscriber) this); } else { parent.add(this); } } - - @Override - public void onComplete() { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - final Stats stats = ((PooledWeightedRSocket) parent).stats; - final long now = stats.stopRequest(startTime); - stats.record(now - startTime); - super.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - Stats stats = ((PooledWeightedRSocket) parent).stats; - stats.stopRequest(startTime); - stats.recordError(0.0); - super.onError(t); - } - } - - @Override - public void cancel() { - long state = REQUESTED.getAndSet(this, STATE_TERMINATED); - if (state == STATE_TERMINATED) { - return; - } - - if (state == STATE_SUBSCRIBED) { - this.s.cancel(); - ((PooledWeightedRSocket) parent).stats.stopRequest(startTime); - } else { - this.parent.remove(this); - ReferenceCountUtil.safeRelease(this.payload); - } - } } - static final class RequestTrackingFluxInner - extends FluxDeferredResolution { + static final class FluxInner extends FluxDeferredResolution { - RequestTrackingFluxInner( - PooledWeightedRSocket parent, INPUT fluxOrPayload, FrameType requestType) { + FluxInner(PooledRSocket parent, INPUT fluxOrPayload, FrameType requestType) { super(parent, fluxOrPayload, requestType); } @@ -333,48 +271,10 @@ public void accept(RSocket rSocket, Throwable t) { return; } - ((PooledWeightedRSocket) parent).stats.startStream(); - source.subscribe(this); } else { parent.add(this); } } - - @Override - public void onComplete() { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - ((PooledWeightedRSocket) parent).stats.stopStream(); - super.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - ((PooledWeightedRSocket) parent).stats.stopStream(); - super.onError(t); - } - } - - @Override - public void cancel() { - long state = REQUESTED.getAndSet(this, STATE_TERMINATED); - if (state == STATE_TERMINATED) { - return; - } - - if (state == STATE_SUBSCRIBED) { - this.s.cancel(); - ((PooledWeightedRSocket) parent).stats.stopStream(); - } else { - this.parent.remove(this); - if (requestType == FrameType.REQUEST_STREAM) { - ReferenceCountUtil.safeRelease(this.fluxOrPayload); - } - } - } } } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java index dbd05abcb..733b06f3c 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java @@ -28,7 +28,6 @@ import java.util.ListIterator; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -43,16 +42,15 @@ class RSocketPool extends ResolvingOperator final DeferredResolutionRSocket deferredResolutionRSocket = new DeferredResolutionRSocket(this); final RSocketConnector connector; final LoadbalanceStrategy loadbalanceStrategy; - final Supplier statsSupplier; - volatile PooledWeightedRSocket[] activeSockets; + volatile PooledRSocket[] activeSockets; - static final AtomicReferenceFieldUpdater ACTIVE_SOCKETS = + static final AtomicReferenceFieldUpdater ACTIVE_SOCKETS = AtomicReferenceFieldUpdater.newUpdater( - RSocketPool.class, PooledWeightedRSocket[].class, "activeSockets"); + RSocketPool.class, PooledRSocket[].class, "activeSockets"); - static final PooledWeightedRSocket[] EMPTY = new PooledWeightedRSocket[0]; - static final PooledWeightedRSocket[] TERMINATED = new PooledWeightedRSocket[0]; + static final PooledRSocket[] EMPTY = new PooledRSocket[0]; + static final PooledRSocket[] TERMINATED = new PooledRSocket[0]; volatile Subscription s; static final AtomicReferenceFieldUpdater S = @@ -64,11 +62,6 @@ public RSocketPool( LoadbalanceStrategy loadbalanceStrategy) { this.connector = connector; this.loadbalanceStrategy = loadbalanceStrategy; - if (loadbalanceStrategy instanceof WeightedLoadbalanceStrategy) { - this.statsSupplier = Stats::create; - } else { - this.statsSupplier = Stats::noOps; - } ACTIVE_SOCKETS.lazySet(this, EMPTY); @@ -105,8 +98,8 @@ public void onNext(List targets) { return; } - PooledWeightedRSocket[] previouslyActiveSockets; - PooledWeightedRSocket[] activeSockets; + PooledRSocket[] previouslyActiveSockets; + PooledRSocket[] activeSockets; for (; ; ) { HashMap rSocketSuppliersCopy = new HashMap<>(); @@ -117,11 +110,11 @@ public void onNext(List targets) { // checking intersection of active RSocket with the newly received set previouslyActiveSockets = this.activeSockets; - PooledWeightedRSocket[] nextActiveSockets = - new PooledWeightedRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()]; + PooledRSocket[] nextActiveSockets = + new PooledRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()]; int position = 0; for (int i = 0; i < previouslyActiveSockets.length; i++) { - PooledWeightedRSocket rSocket = previouslyActiveSockets[i]; + PooledRSocket rSocket = previouslyActiveSockets[i]; Integer index = rSocketSuppliersCopy.remove(rSocket.target()); if (index == null) { @@ -140,11 +133,7 @@ public void onNext(List targets) { // put newly create RSocket instance LoadbalanceTarget target = targets.get(index); nextActiveSockets[position++] = - new PooledWeightedRSocket( - this, - this.connector.connect(target.getTransport()), - target, - this.statsSupplier.get()); + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); } } } @@ -152,11 +141,7 @@ public void onNext(List targets) { // going though brightly new rsocket for (LoadbalanceTarget target : rSocketSuppliersCopy.keySet()) { nextActiveSockets[position++] = - new PooledWeightedRSocket( - this, - this.connector.connect(target.getTransport()), - target, - this.statsSupplier.get()); + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); } // shrank to actual length @@ -215,7 +200,7 @@ RSocket select() { @Nullable RSocket doSelect() { - WeightedRSocket[] sockets = this.activeSockets; + PooledRSocket[] sockets = this.activeSockets; if (sockets == EMPTY) { return null; } @@ -224,8 +209,15 @@ RSocket doSelect() { } @Override - public WeightedRSocket get(int index) { - return activeSockets[index]; + public RSocket get(int index) { + final PooledRSocket socket = activeSockets[index]; + final RSocket realValue = socket.valueIfResolved(); + + if (realValue != null) { + return realValue; + } + + return socket; } @Override @@ -423,7 +415,7 @@ public void clear() { } @Override - public WeightedRSocket set(int index, RSocket element) { + public RSocket set(int index, RSocket element) { throw new UnsupportedOperationException(); } @@ -433,7 +425,7 @@ public void add(int index, RSocket element) { } @Override - public WeightedRSocket remove(int index) { + public RSocket remove(int index) { throw new UnsupportedOperationException(); } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Stats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Stats.java deleted file mode 100644 index 2e9828938..000000000 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/Stats.java +++ /dev/null @@ -1,308 +0,0 @@ -package io.rsocket.loadbalance; - -import io.rsocket.Availability; -import io.rsocket.util.Clock; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; - -class Stats implements Availability { - - private static final double DEFAULT_LOWER_QUANTILE = 0.5; - private static final double DEFAULT_HIGHER_QUANTILE = 0.8; - private static final int INACTIVITY_FACTOR = 500; - private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = - Clock.unit().convert(1L, TimeUnit.SECONDS); - - private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; - - private final Quantile lowerQuantile; - private final Quantile higherQuantile; - private final Ewma errorPercentage; - private final Median median; - private final Ewma interArrivalTime; - - private final long tau; - private final long inactivityFactor; - - private long errorStamp; // last we got an error - private long stamp; // last timestamp we sent a request - private long stamp0; // last timestamp we sent a request or receive a response - private long duration; // instantaneous cumulative duration - - private double availability = 1.0; - - private volatile int pending; // instantaneous rate - private volatile long pendingStreams; // number of active streams - private static final AtomicLongFieldUpdater PENDING_STREAMS = - AtomicLongFieldUpdater.newUpdater(Stats.class, "pendingStreams"); - - private Stats() { - this( - new FrugalQuantile(DEFAULT_LOWER_QUANTILE), - new FrugalQuantile(DEFAULT_HIGHER_QUANTILE), - INACTIVITY_FACTOR); - } - - private Stats(Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { - this.lowerQuantile = lowerQuantile; - this.higherQuantile = higherQuantile; - this.inactivityFactor = inactivityFactor; - - long now = Clock.now(); - this.stamp = now; - this.errorStamp = now; - this.stamp0 = now; - this.duration = 0L; - this.pending = 0; - this.median = new Median(); - this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); - this.errorPercentage = new Ewma(5, TimeUnit.SECONDS, 1.0); - this.tau = Clock.unit().convert((long) (5 / Math.log(2)), TimeUnit.SECONDS); - } - - public double errorPercentage() { - return errorPercentage.value(); - } - - public double medianLatency() { - return median.estimation(); - } - - public double lowerQuantileLatency() { - return lowerQuantile.estimation(); - } - - public double higherQuantileLatency() { - return higherQuantile.estimation(); - } - - public double interArrivalTime() { - return interArrivalTime.value(); - } - - public int pending() { - return pending; - } - - public long lastTimeUsedMillis() { - return stamp0; - } - - @Override - public double availability() { - if (Clock.now() - stamp > tau) { - recordError(1.0); - } - return availability * errorPercentage.value(); - } - - public synchronized double predictedLatency() { - long now = Clock.now(); - long elapsed = Math.max(now - stamp, 1L); - - double weight; - double prediction = median.estimation(); - - if (prediction == 0.0) { - if (pending == 0) { - weight = 0.0; // first request - } else { - // subsequent requests while we don't have any history - weight = STARTUP_PENALTY + pending; - } - } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { - // if we did't see any data for a while, we decay the prediction by inserting - // artificial 0.0 into the median - median.insert(0.0); - weight = median.estimation(); - } else { - double predicted = prediction * pending; - double instant = instantaneous(now); - - if (predicted < instant) { // NB: (0.0 < 0.0) == false - weight = instant / pending; // NB: pending never equal 0 here - } else { - // we are under the predictions - weight = prediction; - } - } - - return weight; - } - - synchronized long instantaneous(long now) { - return duration + (now - stamp0) * pending; - } - - public void startStream() { - PENDING_STREAMS.incrementAndGet(this); - } - - public void stopStream() { - PENDING_STREAMS.decrementAndGet(this); - } - - public synchronized long startRequest() { - long now = Clock.now(); - interArrivalTime.insert(now - stamp); - duration += Math.max(0, now - stamp0) * pending; - pending += 1; - stamp = now; - stamp0 = now; - return now; - } - - public synchronized long stopRequest(long timestamp) { - long now = Clock.now(); - duration += Math.max(0, now - stamp0) * pending - (now - timestamp); - pending -= 1; - stamp0 = now; - return now; - } - - public synchronized void record(double roundTripTime) { - median.insert(roundTripTime); - lowerQuantile.insert(roundTripTime); - higherQuantile.insert(roundTripTime); - } - - public synchronized void recordError(double value) { - errorPercentage.insert(value); - errorStamp = Clock.now(); - } - - public void setAvailability(double availability) { - this.availability = availability; - } - - @Override - public String toString() { - return "Stats{" - + "lowerQuantile=" - + lowerQuantile.estimation() - + ", higherQuantile=" - + higherQuantile.estimation() - + ", inactivityFactor=" - + inactivityFactor - + ", tau=" - + tau - + ", errorPercentage=" - + errorPercentage.value() - + ", pending=" - + pending - + ", errorStamp=" - + errorStamp - + ", stamp=" - + stamp - + ", stamp0=" - + stamp0 - + ", duration=" - + duration - + ", median=" - + median.estimation() - + ", interArrivalTime=" - + interArrivalTime.value() - + ", pendingStreams=" - + pendingStreams - + ", availability=" - + availability - + '}'; - } - - private static final class NoOpsStats extends Stats { - - static final Stats INSTANCE = new NoOpsStats(); - - private NoOpsStats() {} - - @Override - public double errorPercentage() { - return 0.0d; - } - - @Override - public double medianLatency() { - return 0.0d; - } - - @Override - public double lowerQuantileLatency() { - return 0.0d; - } - - @Override - public double higherQuantileLatency() { - return 0.0d; - } - - @Override - public double interArrivalTime() { - return 0; - } - - @Override - public int pending() { - return 0; - } - - @Override - public long lastTimeUsedMillis() { - return 0; - } - - @Override - public double availability() { - return 1.0d; - } - - @Override - public double predictedLatency() { - return 0.0d; - } - - @Override - long instantaneous(long now) { - return 0; - } - - @Override - public void startStream() {} - - @Override - public void stopStream() {} - - @Override - public long startRequest() { - return 0; - } - - @Override - public long stopRequest(long timestamp) { - return 0; - } - - @Override - public void record(double roundTripTime) {} - - @Override - public void recordError(double value) {} - - @Override - public String toString() { - return "NoOpsStats{}"; - } - } - - public static Stats noOps() { - return NoOpsStats.INSTANCE; - } - - public static Stats create() { - return new Stats(); - } - - public static Stats create( - Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { - return new Stats(lowerQuantile, higherQuantile, inactivityFactor); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java index 03bc0530d..cdce957aa 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java @@ -17,9 +17,14 @@ package io.rsocket.loadbalance; import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.RequestInterceptor; import java.util.List; import java.util.SplittableRandom; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; import reactor.util.annotation.Nullable; /** @@ -28,7 +33,7 @@ * * @since 1.1 */ -public class WeightedLoadbalanceStrategy implements LoadbalanceStrategy { +public class WeightedLoadbalanceStrategy implements ClientLoadbalanceStrategy { private static final double EXP_FACTOR = 4.0; @@ -36,18 +41,36 @@ public class WeightedLoadbalanceStrategy implements LoadbalanceStrategy { final int effort; final SplittableRandom splittableRandom; + final Function weightedStatsResolver; public WeightedLoadbalanceStrategy() { - this(EFFORT); + this(new DefaultWeightedStatsResolver()); } - public WeightedLoadbalanceStrategy(int effort) { - this(effort, new SplittableRandom(System.nanoTime())); + public WeightedLoadbalanceStrategy(Function weightedStatsResolver) { + this(EFFORT, weightedStatsResolver); } - public WeightedLoadbalanceStrategy(int effort, SplittableRandom splittableRandom) { + public WeightedLoadbalanceStrategy( + int effort, Function weightedStatsResolver) { + this(effort, new SplittableRandom(System.nanoTime()), weightedStatsResolver); + } + + public WeightedLoadbalanceStrategy( + int effort, + SplittableRandom splittableRandom, + Function weightedStatsResolver) { this.effort = effort; this.splittableRandom = splittableRandom; + this.weightedStatsResolver = weightedStatsResolver; + } + + @Override + public void initialize(RSocketConnector connector) { + final Function resolver = weightedStatsResolver; + if (resolver instanceof DefaultWeightedStatsResolver) { + ((DefaultWeightedStatsResolver) resolver).init(connector); + } } @Override @@ -55,18 +78,19 @@ public RSocket select(List sockets) { final int effort = this.effort; final int size = sockets.size(); - WeightedRSocket weightedRSocket; + RSocket weightedRSocket; + final Function weightedStatsResolver = this.weightedStatsResolver; switch (size) { case 1: - weightedRSocket = (WeightedRSocket) sockets.get(0); + weightedRSocket = sockets.get(0); break; case 2: { - WeightedRSocket rsc1 = (WeightedRSocket) sockets.get(0); - WeightedRSocket rsc2 = (WeightedRSocket) sockets.get(1); + RSocket rsc1 = sockets.get(0); + RSocket rsc2 = sockets.get(1); - double w1 = algorithmicWeight(rsc1); - double w2 = algorithmicWeight(rsc2); + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); if (w1 < w2) { weightedRSocket = rsc2; } else { @@ -76,8 +100,8 @@ public RSocket select(List sockets) { break; default: { - WeightedRSocket rsc1 = null; - WeightedRSocket rsc2 = null; + RSocket rsc1 = null; + RSocket rsc2 = null; for (int i = 0; i < effort; i++) { int i1 = ThreadLocalRandom.current().nextInt(size); @@ -86,19 +110,26 @@ public RSocket select(List sockets) { if (i2 >= i1) { i2++; } - rsc1 = (WeightedRSocket) sockets.get(i1); - rsc2 = (WeightedRSocket) sockets.get(i2); + rsc1 = sockets.get(i1); + rsc2 = sockets.get(i2); if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { break; } } - double w1 = algorithmicWeight(rsc1); - double w2 = algorithmicWeight(rsc2); - if (w1 < w2) { - weightedRSocket = rsc2; - } else { + if (rsc1 != null & rsc2 != null) { + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } else if (rsc1 != null) { weightedRSocket = rsc1; + } else { + weightedRSocket = rsc2; } } } @@ -106,20 +137,19 @@ public RSocket select(List sockets) { return weightedRSocket; } - private static double algorithmicWeight(@Nullable final WeightedRSocket weightedRSocket) { - if (weightedRSocket == null - || weightedRSocket.isDisposed() - || weightedRSocket.availability() == 0.0) { + private static double algorithmicWeight( + RSocket rSocket, @Nullable final WeightedStats weightedStats) { + if (weightedStats == null || rSocket.isDisposed() || rSocket.availability() == 0.0) { return 0.0; } - final Stats stats = weightedRSocket.stats(); - final int pending = stats.pending(); - double latency = stats.predictedLatency(); + final int pending = weightedStats.pending(); + + double latency = weightedStats.predictedLatency(); - final double low = stats.lowerQuantileLatency(); + final double low = weightedStats.lowerQuantileLatency(); final double high = Math.max( - stats.higherQuantileLatency(), + weightedStats.higherQuantileLatency(), low * 1.001); // ensure higherQuantile > lowerQuantile + .1% final double bandWidth = Math.max(high - low, 1); @@ -129,11 +159,41 @@ private static double algorithmicWeight(@Nullable final WeightedRSocket weighted latency *= calculateFactor(latency, high, bandWidth); } - return weightedRSocket.availability() * 1.0 / (1.0 + latency * (pending + 1)); + return rSocket.availability() / (1.0d + latency * (pending + 1)); } private static double calculateFactor(final double u, final double l, final double bandWidth) { final double alpha = (u - l) / bandWidth; return Math.pow(1 + alpha, EXP_FACTOR); } + + static class DefaultWeightedStatsResolver implements Function { + + final ConcurrentMap rsocketsInterceptors = + new ConcurrentHashMap<>(); + + @Override + public WeightedStats apply(RSocket rSocket) { + return rsocketsInterceptors.get(rSocket); + } + + void init(RSocketConnector connector) { + connector.interceptors( + ir -> + ir.forRequester( + (Function) + rSocket -> { + final WeightedStatsRequestInterceptor interceptor = + new WeightedStatsRequestInterceptor() { + @Override + public void dispose() { + rsocketsInterceptors.remove(rSocket); + } + }; + rsocketsInterceptors.put(rSocket, interceptor); + + return interceptor; + })); + } + } } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedRSocket.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedRSocket.java deleted file mode 100644 index 488a7134d..000000000 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedRSocket.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.loadbalance; - -import io.rsocket.RSocket; - -interface WeightedRSocket extends RSocket { - - Stats stats(); -} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java new file mode 100644 index 000000000..b0cf02560 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java @@ -0,0 +1,19 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Availability; + +/** + * Representation of stats used by the {@link WeightedLoadbalanceStrategy} + * + * @since 1.1 + */ +public interface WeightedStats extends Availability { + + double higherQuantileLatency(); + + double lowerQuantileLatency(); + + int pending(); + + double predictedLatency(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java new file mode 100644 index 000000000..f1e790309 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java @@ -0,0 +1,91 @@ +package io.rsocket.loadbalance; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import reactor.util.annotation.Nullable; + +/** + * A {@link RequestInterceptor} implementation + * + * @since 1.1 + */ +public class WeightedStatsRequestInterceptor extends BaseWeightedStats + implements RequestInterceptor { + + final Int2LongHashMap requestsStartTime = new Int2LongHashMap(-1); + + public WeightedStatsRequestInterceptor() { + super(); + } + + @Override + public final void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + final long startTime = startRequest(); + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + requestsStartTime.put(streamId, startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + this.startStream(); + } + } + + @Override + public final void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + long endTime = stopRequest(startTime); + if (t == null) { + record(endTime - startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + + if (t != null) { + updateAvailability(0.0d); + } else { + updateAvailability(1.0d); + } + } + + @Override + public final void onCancel(int streamId, FrameType requestType) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + stopRequest(startTime); + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + } + + @Override + public final void onReject(Throwable rejectionReason, FrameType requestType, ByteBuf metadata) {} + + @Override + public void dispose() {} +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java index 27d10b472..feafdb7a6 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java @@ -97,11 +97,10 @@ public static void main(String[] args) { }); RSocketClient rSocketClient = - LoadbalanceRSocketClient.builder(producer).roundRobinLoadbalanceStrategy().build(); + LoadbalanceRSocketClient.builder(producer).weightedLoadbalanceStrategy().build(); for (int i = 0; i < 10000; i++) { try { - rSocketClient.requestResponse(Mono.just(DefaultPayload.create("test" + i))).block(); } catch (Throwable t) { // no ops