From ca498ec1f1818914d0b3fc9966e724e6c82276ec Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Sat, 25 Jan 2020 01:21:06 +0200 Subject: [PATCH 01/11] provides failing test Signed-off-by: Oleh Dokuka --- .../test/java/io/rsocket/core/RSocketRequesterTest.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 20b1825fa..b6dbf71de 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -253,12 +253,17 @@ protected void hookOnSubscribe(Subscription subscription) {} ByteBuf initialFrame = iterator.next(); Assertions.assertThat(FrameHeaderFlyweight.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); - Assertions.assertThat(RequestChannelFrameFlyweight.initialRequestN(initialFrame)) - .isEqualTo(Integer.MAX_VALUE); + Assertions.assertThat(RequestChannelFrameFlyweight.initialRequestN(initialFrame)).isEqualTo(1); Assertions.assertThat( RequestChannelFrameFlyweight.data(initialFrame).toString(CharsetUtil.UTF_8)) .isEqualTo("0"); + ByteBuf requestNFrame = iterator.next(); + + Assertions.assertThat(FrameHeaderFlyweight.frameType(requestNFrame)).isEqualTo(REQUEST_N); + Assertions.assertThat(RequestNFrameFlyweight.requestN(requestNFrame)) + .isEqualTo(Integer.MAX_VALUE); + Assertions.assertThat(iterator.hasNext()).isFalse(); } From c092f7358d03da87f9dd1fc72cd72a1802215ab2 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Mon, 27 Jan 2020 00:22:29 +0200 Subject: [PATCH 02/11] provides requestChannel bugfix Signed-off-by: Oleh Dokuka --- .../rsocket/internal/FluxSwitchOnFirst.java | 657 +++++++++ .../RateLimitableRequestSubscriber.java | 239 ++++ .../internal/FluxSwitchOnFirstTest.java | 1274 +++++++++++++++++ .../java/io/rsocket/test/TransportTest.java | 3 +- 4 files changed, 2172 insertions(+), 1 deletion(-) create mode 100644 rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java create mode 100755 rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java create mode 100644 rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java diff --git a/rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java b/rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java new file mode 100644 index 000000000..9a6ef4214 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java @@ -0,0 +1,657 @@ +/* + * Copyright (c) 2011-2018 Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiFunction; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxOperator; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Signal; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * @author Oleh Dokuka + * @param + * @param + */ +public final class FluxSwitchOnFirst extends FluxOperator { + static final int STATE_CANCELLED = -2; + static final int STATE_REQUESTED = -1; + + static final int STATE_INIT = 0; + static final int STATE_SUBSCRIBED_ONCE = 1; + + final BiFunction, Flux, Publisher> transformer; + final boolean cancelSourceOnComplete; + + volatile int once; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ONCE = + AtomicIntegerFieldUpdater.newUpdater(FluxSwitchOnFirst.class, "once"); + + public FluxSwitchOnFirst( + Flux source, + BiFunction, Flux, Publisher> transformer, + boolean cancelSourceOnComplete) { + super(source); + this.transformer = Objects.requireNonNull(transformer, "transformer"); + this.cancelSourceOnComplete = cancelSourceOnComplete; + } + + @Override + public int getPrefetch() { + return 1; + } + + @Override + @SuppressWarnings("unchecked") + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + if (actual instanceof Fuseable.ConditionalSubscriber) { + source.subscribe( + new SwitchOnFirstConditionalMain<>( + (Fuseable.ConditionalSubscriber) actual, + transformer, + cancelSourceOnComplete)); + return; + } + source.subscribe(new SwitchOnFirstMain<>(actual, transformer, cancelSourceOnComplete)); + } else { + Operators.error(actual, new IllegalStateException("Allows only a single Subscriber")); + } + } + + abstract static class AbstractSwitchOnFirstMain extends Flux + implements CoreSubscriber, Scannable, Subscription { + + final ControlSubscriber outer; + final BiFunction, Flux, Publisher> transformer; + + Subscription s; + Throwable throwable; + boolean first = true; + boolean done; + + volatile CoreSubscriber inner; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater INNER = + AtomicReferenceFieldUpdater.newUpdater( + AbstractSwitchOnFirstMain.class, CoreSubscriber.class, "inner"); + + volatile int state; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(AbstractSwitchOnFirstMain.class, "state"); + + @SuppressWarnings("unchecked") + AbstractSwitchOnFirstMain( + CoreSubscriber outer, + BiFunction, Flux, Publisher> transformer, + boolean cancelSourceOnComplete) { + this.outer = + outer instanceof Fuseable.ConditionalSubscriber + ? new SwitchOnFirstConditionalControlSubscriber<>( + this, (Fuseable.ConditionalSubscriber) outer, cancelSourceOnComplete) + : new SwitchOnFirstControlSubscriber<>(this, outer, cancelSourceOnComplete); + this.transformer = transformer; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + CoreSubscriber i = this.inner; + + if (key == Attr.CANCELLED) return i == CancelledSubscriber.INSTANCE; + if (key == Attr.TERMINATED) return done || i == CancelledSubscriber.INSTANCE; + + return null; + } + + @Override + public Context currentContext() { + CoreSubscriber actual = inner; + + if (actual != null) { + return actual.currentContext(); + } + + return outer.currentContext(); + } + + @Override + public void cancel() { + if (INNER.getAndSet(this, CancelledSubscriber.INSTANCE) != CancelledSubscriber.INSTANCE) { + s.cancel(); + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + outer.sendSubscription(); + s.request(1); + } + } + + @Override + public void onNext(T t) { + CoreSubscriber i = inner; + if (i == CancelledSubscriber.INSTANCE || done) { + Operators.onNextDropped(t, currentContext()); + return; + } + + if (i == null) { + Publisher result; + CoreSubscriber o = outer; + + try { + result = + Objects.requireNonNull( + transformer.apply(Signal.next(t, o.currentContext()), this), + "The transformer returned a null value"); + } catch (Throwable e) { + done = true; + Operators.error(o, Operators.onOperatorError(s, e, t, o.currentContext())); + return; + } + + first = false; + result.subscribe(o); + return; + } + + i.onNext(t); + } + + @Override + public void onError(Throwable t) { + CoreSubscriber i = inner; + if (i == CancelledSubscriber.INSTANCE || done) { + Operators.onErrorDropped(t, currentContext()); + return; + } + + throwable = t; + done = true; + + if (first && i == null) { + Publisher result; + CoreSubscriber o = outer; + + try { + result = + Objects.requireNonNull( + transformer.apply(Signal.error(t, o.currentContext()), this), + "The transformer returned a null value"); + } catch (Throwable e) { + done = true; + Operators.error(o, Operators.onOperatorError(s, e, t, o.currentContext())); + return; + } + + first = false; + result.subscribe(o); + } + + i = this.inner; + if (i != null) { + i.onError(t); + } + } + + @Override + public void onComplete() { + CoreSubscriber i = inner; + if (i == CancelledSubscriber.INSTANCE || done) { + return; + } + + done = true; + + if (first && i == null) { + Publisher result; + CoreSubscriber o = outer; + + try { + result = + Objects.requireNonNull( + transformer.apply(Signal.complete(o.currentContext()), this), + "The transformer returned a null value"); + } catch (Throwable e) { + done = true; + Operators.error(o, Operators.onOperatorError(s, e, null, o.currentContext())); + return; + } + + first = false; + result.subscribe(o); + } + + i = inner; + if (i != null) { + i.onComplete(); + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + s.request(n); + } + } + } + + static final class SwitchOnFirstMain extends AbstractSwitchOnFirstMain { + + SwitchOnFirstMain( + CoreSubscriber outer, + BiFunction, Flux, Publisher> transformer, + boolean cancelSourceOnComplete) { + super(outer, transformer, cancelSourceOnComplete); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (state == STATE_INIT && STATE.compareAndSet(this, STATE_INIT, STATE_SUBSCRIBED_ONCE)) { + if (done) { + if (throwable != null) { + Operators.error(actual, throwable); + } else { + Operators.complete(actual); + } + return; + } + INNER.lazySet(this, actual); + actual.onSubscribe(this); + } else { + Operators.error( + actual, new IllegalStateException("FluxSwitchOnFirst allows only one Subscriber")); + } + } + } + + static final class SwitchOnFirstConditionalMain extends AbstractSwitchOnFirstMain + implements Fuseable.ConditionalSubscriber { + + SwitchOnFirstConditionalMain( + Fuseable.ConditionalSubscriber outer, + BiFunction, Flux, Publisher> transformer, + boolean cancelSourceOnComplete) { + super(outer, transformer, cancelSourceOnComplete); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (state == STATE_INIT && STATE.compareAndSet(this, STATE_INIT, STATE_SUBSCRIBED_ONCE)) { + if (done) { + if (throwable != null) { + Operators.error(actual, throwable); + } else { + Operators.complete(actual); + } + return; + } + INNER.lazySet(this, Operators.toConditionalSubscriber(actual)); + actual.onSubscribe(this); + } else { + Operators.error( + actual, new IllegalStateException("FluxSwitchOnFirst allows only one Subscriber")); + } + } + + @Override + @SuppressWarnings("unchecked") + public boolean tryOnNext(T t) { + CoreSubscriber i = inner; + if (i == CancelledSubscriber.INSTANCE || done) { + Operators.onNextDropped(t, currentContext()); + return false; + } + + if (i == null) { + Publisher result; + CoreSubscriber o = outer; + + try { + result = + Objects.requireNonNull( + transformer.apply(Signal.next(t, o.currentContext()), this), + "The transformer returned a null value"); + } catch (Throwable e) { + done = true; + Operators.error(o, Operators.onOperatorError(s, e, t, o.currentContext())); + return false; + } + + first = false; + result.subscribe(o); + return true; + } + + return ((Fuseable.ConditionalSubscriber) i).tryOnNext(t); + } + } + + static final class SwitchOnFirstControlSubscriber + implements ControlSubscriber, Scannable, Subscription { + + final AbstractSwitchOnFirstMain parent; + final CoreSubscriber actual; + final boolean cancelSourceOnComplete; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(SwitchOnFirstControlSubscriber.class, "requested"); + + Subscription s; + + SwitchOnFirstControlSubscriber( + AbstractSwitchOnFirstMain parent, + CoreSubscriber actual, + boolean cancelSourceOnComplete) { + this.parent = parent; + this.actual = actual; + this.cancelSourceOnComplete = cancelSourceOnComplete; + } + + @Override + public Context currentContext() { + return actual.currentContext(); + } + + @Override + public void sendSubscription() { + actual.onSubscribe(this); + } + + @Override + public void onSubscribe(Subscription s) { + if (this.s == null && this.requested != STATE_CANCELLED) { + this.s = s; + + tryRequest(); + } else { + s.cancel(); + } + } + + @Override + public void onNext(T t) { + actual.onNext(t); + } + + @Override + public void onError(Throwable throwable) { + if (!parent.done) { + parent.cancel(); + } + + actual.onError(throwable); + } + + @Override + public void onComplete() { + if (!parent.done && cancelSourceOnComplete) { + parent.cancel(); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + long r = this.requested; + + if (r > STATE_REQUESTED) { + long u; + for (; ; ) { + if (r == Long.MAX_VALUE) { + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + return; + } else { + r = requested; + + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_CANCELLED) { + return; + } + + s.request(n); + } + + void tryRequest() { + final Subscription s = this.s; + long r = REQUESTED.getAndSet(this, -1); + + if (r > 0) { + s.request(r); + } + } + + @Override + public void cancel() { + final long state = REQUESTED.getAndSet(this, STATE_CANCELLED); + + if (state == STATE_CANCELLED) { + return; + } + + if (state == STATE_REQUESTED) { + s.cancel(); + } + + parent.cancel(); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return parent; + if (key == Attr.ACTUAL) return actual; + + return null; + } + } + + static final class SwitchOnFirstConditionalControlSubscriber + implements ControlSubscriber, Scannable, Subscription, Fuseable.ConditionalSubscriber { + + final AbstractSwitchOnFirstMain parent; + final Fuseable.ConditionalSubscriber actual; + final boolean terminateUpstreamOnComplete; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater( + SwitchOnFirstConditionalControlSubscriber.class, "requested"); + + Subscription s; + + SwitchOnFirstConditionalControlSubscriber( + AbstractSwitchOnFirstMain parent, + Fuseable.ConditionalSubscriber actual, + boolean terminateUpstreamOnComplete) { + this.parent = parent; + this.actual = actual; + this.terminateUpstreamOnComplete = terminateUpstreamOnComplete; + } + + @Override + public void sendSubscription() { + actual.onSubscribe(this); + } + + @Override + public Context currentContext() { + return actual.currentContext(); + } + + @Override + public void onSubscribe(Subscription s) { + if (this.s == null && this.requested != STATE_CANCELLED) { + this.s = s; + + tryRequest(); + } else { + s.cancel(); + } + } + + @Override + public void onNext(T t) { + actual.onNext(t); + } + + @Override + public boolean tryOnNext(T t) { + return actual.tryOnNext(t); + } + + @Override + public void onError(Throwable throwable) { + if (!parent.done) { + parent.cancel(); + } + + actual.onError(throwable); + } + + @Override + public void onComplete() { + if (!parent.done && terminateUpstreamOnComplete) { + parent.cancel(); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + long r = this.requested; + + if (r > STATE_REQUESTED) { + long u; + for (; ; ) { + if (r == Long.MAX_VALUE) { + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + return; + } else { + r = requested; + + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_CANCELLED) { + return; + } + + s.request(n); + } + + void tryRequest() { + final Subscription s = this.s; + long r = REQUESTED.getAndSet(this, -1); + + if (r > 0) { + s.request(r); + } + } + + @Override + public void cancel() { + final long state = REQUESTED.getAndSet(this, STATE_CANCELLED); + + if (state == STATE_CANCELLED) { + return; + } + + if (state == STATE_REQUESTED) { + s.cancel(); + return; + } + + parent.cancel(); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return parent; + if (key == Attr.ACTUAL) return actual; + + return null; + } + } + + interface ControlSubscriber extends CoreSubscriber { + + void sendSubscription(); + } + + static final class CancelledSubscriber implements CoreSubscriber { + + static final CancelledSubscriber INSTANCE = new CancelledSubscriber(); + + private CancelledSubscriber() {} + + @Override + public void onSubscribe(Subscription s) {} + + @Override + public void onNext(Object o) {} + + @Override + public void onError(Throwable t) {} + + @Override + public void onComplete() {} + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java b/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java new file mode 100755 index 000000000..2fb18e1f7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java @@ -0,0 +1,239 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal; + +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; + +/** */ +public abstract class RateLimitableRequestSubscriber implements CoreSubscriber, Subscription { + + private final long prefetch; + private final long limit; + + private long externalRequested; // need sync + private int pendingToFulfil; // need sync since should be checked/zerroed in onNext + // and increased in request + private int deliveredElements; // no need to sync since increased zerroed only in + // the request method + + private volatile Subscription subscription; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RateLimitableRequestSubscriber.class, Subscription.class, "subscription"); + + public RateLimitableRequestSubscriber(long prefetch) { + this.prefetch = prefetch; + this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : (prefetch - (prefetch >> 2)); + } + + protected void hookOnSubscribe(Subscription s) { + // NO-OP + } + + protected void hookOnNext(T value) { + // NO-OP + } + + protected void hookOnComplete() { + // NO-OP + } + + protected void hookOnError(Throwable throwable) { + throw Exceptions.errorCallbackNotImplemented(throwable); + } + + protected void hookOnCancel() { + // NO-OP + } + + protected void hookFinally(SignalType type) { + // NO-OP + } + + void safeHookFinally(SignalType type) { + try { + hookFinally(type); + } catch (Throwable finallyFailure) { + Operators.onErrorDropped(finallyFailure, currentContext()); + } + } + + @Override + public final void onSubscribe(Subscription s) { + if (Operators.validate(subscription, s)) { + this.subscription = s; + try { + hookOnSubscribe(s); + requestN(); + } catch (Throwable throwable) { + onError(Operators.onOperatorError(s, throwable, currentContext())); + } + } + } + + @Override + public final void onNext(T t) { + try { + hookOnNext(t); + + if (prefetch == Integer.MAX_VALUE) { + return; + } + + final long l = limit; + int d = deliveredElements + 1; + + if (d == l) { + d = 0; + final long r; + final Subscription s = subscription; + + if (s == null) { + return; + } + + synchronized (this) { + long er = externalRequested; + + if (er >= l) { + er -= l; + // keep pendingToFulfil as is since it is eq to prefetch + r = l; + } else { + pendingToFulfil -= l; + if (er > 0) { + r = er; + er = 0; + pendingToFulfil += r; + } else { + r = 0; + } + } + + externalRequested = er; + } + + if (r > 0) { + s.request(r); + } + } + + deliveredElements = d; + } catch (Throwable e) { + onError(e); + } + } + + @Override + public final void onError(Throwable t) { + Subscription s = S.getAndSet(this, Operators.cancelledSubscription()); + if (s == Operators.cancelledSubscription()) { + Operators.onErrorDropped(t, this.currentContext()); + return; + } + + try { + hookOnError(t); + } catch (Throwable e) { + e = Exceptions.addSuppressed(e, t); + Operators.onErrorDropped(e, currentContext()); + } finally { + safeHookFinally(SignalType.ON_ERROR); + } + } + + @Override + public final void onComplete() { + if (S.getAndSet(this, Operators.cancelledSubscription()) != Operators.cancelledSubscription()) { + // we're sure it has not been concurrently cancelled + try { + hookOnComplete(); + } catch (Throwable throwable) { + // onError itself will short-circuit due to the CancelledSubscription being set above + hookOnError(Operators.onOperatorError(throwable, currentContext())); + } finally { + safeHookFinally(SignalType.ON_COMPLETE); + } + } + } + + @Override + public final void request(long n) { + synchronized (this) { + long requested = externalRequested; + if (requested == Long.MAX_VALUE) { + return; + } + externalRequested = Operators.addCap(n, requested); + } + + requestN(); + } + + private void requestN() { + final long r; + final Subscription s = subscription; + + if (s == null) { + return; + } + + synchronized (this) { + final long er = externalRequested; + final long p = prefetch; + final int pendingFulfil = pendingToFulfil; + + if (er != Long.MAX_VALUE || p != Integer.MAX_VALUE) { + // shortcut + if (pendingFulfil == p) { + return; + } + + r = Math.min(p - pendingFulfil, er); + if (er != Long.MAX_VALUE) { + externalRequested -= r; + } + if (p != Integer.MAX_VALUE) { + pendingToFulfil += r; + } + } else { + r = Long.MAX_VALUE; + } + } + + if (r > 0) { + s.request(r); + } + } + + public final void cancel() { + if (Operators.terminate(S, this)) { + try { + hookOnCancel(); + } catch (Throwable throwable) { + hookOnError(Operators.onOperatorError(subscription, throwable, currentContext())); + } finally { + safeHookFinally(SignalType.CANCEL); + } + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java b/rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java new file mode 100644 index 000000000..7628a5304 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java @@ -0,0 +1,1274 @@ +/* + * Copyright (c) 2011-2018 Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import org.assertj.core.api.Assertions; +import org.junit.Test; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Fuseable; +import reactor.core.publisher.EmitterProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.ReplayProcessor; +import reactor.core.publisher.Signal; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; + +public class FluxSwitchOnFirstTest { + + @Test + public void shouldNotSubscribeTwice() { + Throwable[] throwables = new Throwable[1]; + CountDownLatch latch = new CountDownLatch(1); + StepVerifier.create( + Flux.just(1L) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + RaceTestUtils.race( + () -> + f.subscribe( + __ -> {}, + t -> { + throwables[0] = t; + latch.countDown(); + }, + latch::countDown), + () -> + f.subscribe( + __ -> {}, + t -> { + throwables[0] = t; + latch.countDown(); + }, + latch::countDown)); + + return Flux.empty(); + }, + false))) + .expectSubscription() + .expectComplete() + .verify(); + + Assertions.assertThat(throwables[0]) + .isInstanceOf(IllegalStateException.class) + .hasMessage("FluxSwitchOnFirst allows only one Subscriber"); + } + + @Test + public void shouldNotSubscribeTwiceConditional() { + Throwable[] throwables = new Throwable[1]; + CountDownLatch latch = new CountDownLatch(1); + StepVerifier.create( + Flux.just(1L) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + RaceTestUtils.race( + () -> + f.subscribe( + __ -> {}, + t -> { + throwables[0] = t; + latch.countDown(); + }, + latch::countDown), + () -> + f.subscribe( + __ -> {}, + t -> { + throwables[0] = t; + latch.countDown(); + }, + latch::countDown)); + + return Flux.empty(); + }, + false) + .filter(e -> true))) + .expectSubscription() + .expectComplete() + .verify(); + + Assertions.assertThat(throwables[0]) + .isInstanceOf(IllegalStateException.class) + .hasMessage("FluxSwitchOnFirst allows only one Subscriber"); + } + + @Test + public void shouldNotSubscribeTwiceWhenCanceled() { + CountDownLatch latch = new CountDownLatch(1); + StepVerifier.create( + Flux.just(1L, 2L) + .doOnComplete( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }) + .hide() + .publishOn(Schedulers.parallel()) + .cancelOn(NoOpsScheduler.INSTANCE) + .doOnCancel(latch::countDown) + .transform(flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> f, false)) + .doOnSubscribe( + s -> Schedulers.single().schedule(s::cancel, 10, TimeUnit.MILLISECONDS))) + .expectSubscription() + .expectNext(2L) + .expectNoEvent(Duration.ofMillis(200)) + .thenCancel() + .verifyThenAssertThat() + .hasNotDroppedErrors(); + } + + @Test + public void shouldNotSubscribeTwiceConditionalWhenCanceled() { + CountDownLatch latch = new CountDownLatch(1); + StepVerifier.create( + Flux.just(1L, 2L) + .doOnComplete( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }) + .hide() + .publishOn(Schedulers.parallel()) + .cancelOn(NoOpsScheduler.INSTANCE) + .doOnCancel(latch::countDown) + .transform(flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> f, false)) + .filter(e -> true) + .doOnSubscribe( + s -> Schedulers.single().schedule(s::cancel, 10, TimeUnit.MILLISECONDS))) + .expectSubscription() + .expectNext(2L) + .expectNoEvent(Duration.ofMillis(200)) + .thenCancel() + .verifyThenAssertThat() + .hasNotDroppedErrors(); + } + + @Test + public void shouldSendOnErrorSignalConditional() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + RuntimeException error = new RuntimeException(); + StepVerifier.create( + Flux.error(error) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f; + }, + false)) + .filter(e -> true)) + .expectSubscription() + .expectError(RuntimeException.class) + .verify(); + + Assertions.assertThat(first).containsExactly(Signal.error(error)); + } + + @Test + public void shouldSendOnNextSignalConditional() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + StepVerifier.create( + Flux.just(1L) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f; + }, + false)) + .filter(e -> true)) + .expectSubscription() + .expectComplete() + .verify(); + + Assertions.assertThat((long) first[0].get()).isEqualTo(1L); + } + + @Test + public void shouldSendOnErrorSignalWithDelaySubscription() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + RuntimeException error = new RuntimeException(); + StepVerifier.create( + Flux.error(error) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f.delaySubscription(Duration.ofMillis(100)); + }, + false))) + .expectSubscription() + .expectError(RuntimeException.class) + .verify(); + + Assertions.assertThat(first).containsExactly(Signal.error(error)); + } + + @Test + public void shouldSendOnCompleteSignalWithDelaySubscription() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + StepVerifier.create( + Flux.empty() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f.delaySubscription(Duration.ofMillis(100)); + }, + false))) + .expectSubscription() + .expectComplete() + .verify(); + + Assertions.assertThat(first).containsExactly(Signal.complete()); + } + + @Test + public void shouldSendOnErrorSignal() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + RuntimeException error = new RuntimeException(); + StepVerifier.create( + Flux.error(error) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f; + }, + false))) + .expectSubscription() + .expectError(RuntimeException.class) + .verify(); + + Assertions.assertThat(first).containsExactly(Signal.error(error)); + } + + @Test + public void shouldSendOnNextSignal() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + StepVerifier.create( + Flux.just(1L) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f; + }, + false))) + .expectSubscription() + .expectComplete() + .verify(); + + Assertions.assertThat((long) first[0].get()).isEqualTo(1L); + } + + @Test + public void shouldSendOnNextAsyncSignal() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + StepVerifier.create( + Flux.just(1L) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f.subscribeOn(Schedulers.elastic()); + }, + false))) + .expectSubscription() + .expectComplete() + .verify(); + + Assertions.assertThat((long) first[0].get()).isEqualTo(1L); + } + + @Test + public void shouldSendOnNextAsyncSignalConditional() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + StepVerifier.create( + Flux.just(1L) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + + return f.subscribeOn(Schedulers.elastic()); + }, + false)) + .filter(p -> true)) + .expectSubscription() + .expectComplete() + .verify(); + + Assertions.assertThat((long) first[0].get()).isEqualTo(1L); + } + + @Test + public void shouldNeverSendIncorrectRequestSizeToUpstream() throws InterruptedException { + TestPublisher publisher = TestPublisher.createCold(); + AtomicLong capture = new AtomicLong(-1); + ArrayList requested = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + Flux switchTransformed = + publisher + .flux() + .doOnRequest(requested::add) + .transform( + flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)); + + publisher.next(1L); + publisher.complete(); + + switchTransformed.subscribe(capture::set, __ -> {}, latch::countDown, s -> s.request(1)); + + latch.await(5, TimeUnit.SECONDS); + + Assertions.assertThat(capture.get()).isEqualTo(-1); + Assertions.assertThat(requested).containsExactly(1L, 1L); + } + + @Test + public void shouldNeverSendIncorrectRequestSizeToUpstreamConditional() + throws InterruptedException { + TestPublisher publisher = TestPublisher.createCold(); + AtomicLong capture = new AtomicLong(-1); + ArrayList requested = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + Flux switchTransformed = + publisher + .flux() + .doOnRequest(e1 -> requested.add(e1)) + .transform( + flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)) + .filter(e -> true); + + publisher.next(1L); + publisher.complete(); + + switchTransformed.subscribe(capture::set, __ -> {}, latch::countDown, s -> s.request(1)); + + latch.await(5, TimeUnit.SECONDS); + + Assertions.assertThat(capture.get()).isEqualTo(-1L); + Assertions.assertThat(requested).containsExactly(1L, 1L); + } + + @Test + public void shouldBeRequestedOneFromUpstreamTwiceInCaseOfConditional() + throws InterruptedException { + TestPublisher publisher = TestPublisher.createCold(); + ArrayList capture = new ArrayList<>(); + ArrayList requested = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + Flux switchTransformed = + publisher + .flux() + .doOnRequest(requested::add) + .transform( + flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)) + .filter(e -> false); + + publisher.next(1L); + publisher.complete(); + + switchTransformed.subscribe(capture::add, __ -> {}, latch::countDown, s -> s.request(1)); + + latch.await(5, TimeUnit.SECONDS); + + Assertions.assertThat(capture).isEmpty(); + Assertions.assertThat(requested).containsExactly(1L, 1L); + } + + @Test + public void shouldBeRequestedExactlyOneAndThenLongMaxValue() throws InterruptedException { + TestPublisher publisher = TestPublisher.createCold(); + ArrayList capture = new ArrayList<>(); + ArrayList requested = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + Flux switchTransformed = + publisher + .flux() + .doOnRequest(requested::add) + .transform( + flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)); + + publisher.next(1L); + publisher.complete(); + + switchTransformed.subscribe(capture::add, __ -> {}, latch::countDown); + + latch.await(5, TimeUnit.SECONDS); + + Assertions.assertThat(capture).isEmpty(); + Assertions.assertThat(requested).containsExactly(1L, Long.MAX_VALUE); + } + + @Test + public void shouldBeRequestedExactlyOneAndThenLongMaxValueConditional() + throws InterruptedException { + TestPublisher publisher = TestPublisher.createCold(); + ArrayList capture = new ArrayList<>(); + ArrayList requested = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + Flux switchTransformed = + publisher + .flux() + .doOnRequest(requested::add) + .transform( + flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)); + + publisher.next(1L); + publisher.complete(); + + switchTransformed.subscribe(capture::add, __ -> {}, latch::countDown); + + latch.await(5, TimeUnit.SECONDS); + + Assertions.assertThat(capture).isEmpty(); + Assertions.assertThat(requested).containsExactly(1L, Long.MAX_VALUE); + } + + @Test + public void shouldReturnCorrectContextOnEmptySource() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + Flux switchTransformed = + Flux.empty() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (f, innerFlux) -> { + first[0] = f; + return innerFlux; + }, + false)) + .subscriberContext(Context.of("a", "c")) + .subscriberContext(Context.of("c", "d")); + + StepVerifier.create(switchTransformed, 0) + .expectSubscription() + .thenRequest(1) + .expectAccessibleContext() + .contains("a", "c") + .contains("c", "d") + .then() + .expectComplete() + .verify(); + + Assertions.assertThat(first) + .containsExactly(Signal.complete(Context.of("a", "c").put("c", "d"))); + } + + @Test + public void shouldNotFailOnIncorrectPublisherBehavior() { + TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.CLEANUP_ON_TERMINATE); + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (first, innerFlux) -> innerFlux.subscriberContext(Context.of("a", "b")), + false)); + + StepVerifier.create( + new Flux() { + @Override + public void subscribe(CoreSubscriber actual) { + switchTransformed.subscribe(actual); + publisher.next(1L); + } + }, + 0) + .thenRequest(1) + .then(() -> publisher.next(2L)) + .expectNext(2L) + .then(() -> publisher.error(new RuntimeException())) + .then(() -> publisher.error(new RuntimeException())) + .then(() -> publisher.error(new RuntimeException())) + .then(() -> publisher.error(new RuntimeException())) + .expectError() + .verifyThenAssertThat() + .hasDroppedErrors(3) + .tookLessThan(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + } + + @Test + public void shouldBeAbleToAccessUpstreamContext() { + TestPublisher publisher = TestPublisher.createCold(); + + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (first, innerFlux) -> + innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")), + false)) + .subscriberContext(Context.of("a", "c")) + .subscriberContext(Context.of("c", "d")); + + publisher.next(1L); + publisher.next(2L); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectNext("2") + .thenRequest(1) + .then(() -> publisher.next(3L)) + .expectNext("3") + .expectAccessibleContext() + .contains("a", "b") + .contains("c", "d") + .then() + .then(publisher::complete) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + } + + @Test + public void shouldNotHangWhenOneElementUpstream() { + TestPublisher publisher = TestPublisher.createCold(); + + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (first, innerFlux) -> + innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")), + false)) + .subscriberContext(Context.of("a", "c")) + .subscriberContext(Context.of("c", "d")); + + publisher.next(1L); + publisher.complete(); + + StepVerifier.create(switchTransformed, 0).expectComplete().verify(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + } + + @Test + public void backpressureTest() { + TestPublisher publisher = TestPublisher.createCold(); + AtomicLong requested = new AtomicLong(); + + Flux switchTransformed = + publisher + .flux() + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); + + publisher.next(1L); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .then(() -> publisher.next(2L)) + .expectNext("2") + .then(publisher::complete) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + + Assertions.assertThat(requested.get()).isEqualTo(2L); + } + + @Test + public void backpressureConditionalTest() { + Flux publisher = Flux.range(0, 10000); + AtomicLong requested = new AtomicLong(); + + Flux switchTransformed = + publisher + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)) + .filter(e -> false); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + Assertions.assertThat(requested.get()).isEqualTo(2L); + } + + @Test + public void backpressureHiddenConditionalTest() { + Flux publisher = Flux.range(0, 10000); + AtomicLong requested = new AtomicLong(); + + Flux switchTransformed = + publisher + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf).hide(), false)) + .filter(e -> false); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + Assertions.assertThat(requested.get()).isEqualTo(10001L); + } + + @Test + public void backpressureDrawbackOnConditionalInTransformTest() { + Flux publisher = Flux.range(0, 10000); + AtomicLong requested = new AtomicLong(); + + Flux switchTransformed = + publisher + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (first, innerFlux) -> innerFlux.map(String::valueOf).filter(e -> false), + false)); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + Assertions.assertThat(requested.get()).isEqualTo(10001L); + } + + @Test + public void shouldErrorOnOverflowTest() { + TestPublisher publisher = TestPublisher.createCold(); + + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); + + publisher.next(1L); + publisher.next(2L); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectNext("2") + .then(() -> publisher.next(2L)) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Can't deliver value due to lack of requests")) + .verify(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + } + + @Test + public void shouldPropagateonCompleteCorrectly() { + Flux switchTransformed = + Flux.empty() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); + + StepVerifier.create(switchTransformed).expectComplete().verify(Duration.ofSeconds(10)); + } + + @Test + public void shouldPropagateOnCompleteWithMergedElementsCorrectly() { + Flux switchTransformed = + Flux.empty() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (first, innerFlux) -> + innerFlux.map(String::valueOf).mergeWith(Flux.just("1", "2", "3")), + false)); + + StepVerifier.create(switchTransformed) + .expectNext("1", "2", "3") + .expectComplete() + .verify(Duration.ofSeconds(10)); + } + + @Test + public void shouldPropagateErrorCorrectly() { + Flux switchTransformed = + Flux.error(new RuntimeException("hello")) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf), true)); + + StepVerifier.create(switchTransformed) + .expectErrorMessage("hello") + .verify(Duration.ofSeconds(10)); + } + + @Test + public void shouldBeAbleToBeCancelledProperly() { + TestPublisher publisher = TestPublisher.createCold(); + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); + + publisher.next(1); + + StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); + + publisher.assertCancelled(); + publisher.assertWasRequested(); + } + + @Test + public void shouldBeAbleToBeCancelledProperly2() { + TestPublisher publisher = TestPublisher.createCold(); + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf).take(1), false)); + + publisher.next(1); + publisher.next(2); + publisher.next(3); + publisher.next(4); + + StepVerifier.create(switchTransformed, 1) + .expectNext("2") + .expectComplete() + .verify(Duration.ofSeconds(10)); + + publisher.assertCancelled(); + publisher.assertWasRequested(); + } + + @Test + public void shouldBeAbleToBeCancelledProperly3() { + TestPublisher publisher = TestPublisher.createCold(); + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)) + .take(1); + + publisher.next(1); + publisher.next(2); + publisher.next(3); + publisher.next(4); + + StepVerifier.create(switchTransformed, 1) + .expectNext("2") + .expectComplete() + .verify(Duration.ofSeconds(10)); + + publisher.assertCancelled(); + publisher.assertWasRequested(); + } + + @Test + public void shouldReturnNormallyIfExceptionIsThrownOnNextDuringSwitching() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + Optional expectedCause = Optional.of(1L); + + StepVerifier.create( + Flux.just(1L) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + throw new NullPointerException(); + }, + false))) + .expectSubscription() + .expectError(NullPointerException.class) + .verifyThenAssertThat() + .hasOperatorErrorsSatisfying( + c -> + Assertions.assertThat(c) + .hasOnlyOneElementSatisfying( + t -> { + Assertions.assertThat(t.getT1()) + .containsInstanceOf(NullPointerException.class); + Assertions.assertThat(t.getT2()).isEqualTo(expectedCause); + })); + + Assertions.assertThat((long) first[0].get()).isEqualTo(1L); + } + + @Test + public void shouldReturnNormallyIfExceptionIsThrownOnErrorDuringSwitching() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + NullPointerException npe = new NullPointerException(); + RuntimeException error = new RuntimeException(); + StepVerifier.create( + Flux.error(error) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + throw npe; + }, + false))) + .expectSubscription() + .verifyError(NullPointerException.class); + + Assertions.assertThat(first).containsExactly(Signal.error(error)); + } + + @Test + public void shouldReturnNormallyIfExceptionIsThrownOnCompleteDuringSwitching() { + @SuppressWarnings("unchecked") + Signal[] first = new Signal[1]; + + StepVerifier.create( + Flux.empty() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + first[0] = s; + throw new NullPointerException(); + }, + false))) + .expectSubscription() + .expectError(NullPointerException.class) + .verifyThenAssertThat() + .hasOperatorErrorMatching( + t -> { + Assertions.assertThat(t).isInstanceOf(NullPointerException.class); + return true; + }); + + Assertions.assertThat(first).containsExactly(Signal.complete()); + } + + @Test + public void sourceSubscribedOnce() { + AtomicInteger subCount = new AtomicInteger(); + Flux source = + Flux.range(1, 10).hide().doOnSubscribe(subscription -> subCount.incrementAndGet()); + + StepVerifier.create( + source.transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (s, f) -> f.filter(v -> v % 2 == s.get()), false))) + .expectNext(3, 5, 7, 9) + .verifyComplete(); + + Assertions.assertThat(subCount).hasValue(1); + } + + @Test + public void checkHotSource() { + ReplayProcessor processor = ReplayProcessor.create(1); + + processor.onNext(1L); + processor.onNext(2L); + processor.onNext(3L); + + StepVerifier.create( + processor.transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (s, f) -> f.filter(v -> v % s.get() == 0), false))) + .then( + () -> { + processor.onNext(4L); + processor.onNext(5L); + processor.onNext(6L); + processor.onNext(7L); + processor.onNext(8L); + processor.onNext(9L); + processor.onComplete(); + }) + .expectNext(6L, 9L) + .verifyComplete(); + } + + @Test + public void shouldCancelSourceOnUnrelatedPublisherComplete() { + EmitterProcessor testPublisher = EmitterProcessor.create(); + + testPublisher.onNext(1L); + + StepVerifier.create( + testPublisher.transform( + flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> Flux.empty(), true))) + .expectSubscription() + .verifyComplete(); + + Assertions.assertThat(testPublisher.isCancelled()).isTrue(); + } + + @Test + public void shouldCancelSourceOnUnrelatedPublisherError() { + EmitterProcessor testPublisher = EmitterProcessor.create(); + + testPublisher.onNext(1L); + + StepVerifier.create( + testPublisher.transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (s, f) -> Flux.error(new RuntimeException("test")), false))) + .expectSubscription() + .verifyErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isExactlyInstanceOf(RuntimeException.class)); + + Assertions.assertThat(testPublisher.isCancelled()).isTrue(); + } + + @Test + public void shouldCancelSourceOnUnrelatedPublisherCancel() { + TestPublisher testPublisher = TestPublisher.create(); + + StepVerifier.create( + testPublisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (s, f) -> Flux.error(new RuntimeException("test")), false))) + .expectSubscription() + .thenCancel() + .verify(); + + Assertions.assertThat(testPublisher.wasCancelled()).isTrue(); + } + + @Test + public void shouldCancelUpstreamBeforeFirst() { + EmitterProcessor testPublisher = EmitterProcessor.create(); + + StepVerifier.create( + testPublisher.transform( + flux -> + new FluxSwitchOnFirst<>( + flux, (s, f) -> Flux.error(new RuntimeException("test")), false))) + .thenAwait(Duration.ofMillis(50)) + .thenCancel() + .verify(Duration.ofSeconds(2)); + + Assertions.assertThat(testPublisher.isCancelled()).isTrue(); + } + + @Test + public void shouldContinueWorkingRegardlessTerminalOnDownstream() { + TestPublisher testPublisher = TestPublisher.create(); + + Flux[] intercepted = new Flux[1]; + + StepVerifier.create( + testPublisher + .flux() + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> { + intercepted[0] = f; + return Flux.just(2L); + }, + false))) + .expectSubscription() + .then(() -> testPublisher.next(1L)) + .expectNext(2L) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + Assertions.assertThat(testPublisher.wasCancelled()).isFalse(); + + StepVerifier.create(intercepted[0]) + .expectSubscription() + .then(testPublisher::complete) + .expectComplete() + .verify(Duration.ofSeconds(1)); + } + + @Test + public void shouldCancelSourceOnOnDownstreamTerminal() { + TestPublisher testPublisher = TestPublisher.create(); + + StepVerifier.create( + testPublisher + .flux() + .transform(flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> Flux.just(1L), true))) + .expectSubscription() + .then(() -> testPublisher.next(1L)) + .expectNext(1L) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + Assertions.assertThat(testPublisher.wasCancelled()).isTrue(); + } + + @Test + public void racingTest() { + for (int i = 0; i < 1000; i++) { + CoreSubscriber[] subscribers = new CoreSubscriber[1]; + Subscription[] downstreamSubscriptions = new Subscription[1]; + Subscription[] innerSubscriptions = new Subscription[1]; + + AtomicLong requested = new AtomicLong(); + + Flux.just(2) + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> + new Flux() { + + @Override + public void subscribe(CoreSubscriber actual) { + subscribers[0] = actual; + f.subscribe( + actual::onNext, + actual::onError, + actual::onComplete, + (s) -> innerSubscriptions[0] = s); + } + }, + false)) + .subscribe(__ -> {}, __ -> {}, () -> {}, s -> downstreamSubscriptions[0] = s); + + CoreSubscriber subscriber = subscribers[0]; + Subscription downstreamSubscription = downstreamSubscriptions[0]; + Subscription innerSubscription = innerSubscriptions[0]; + innerSubscription.request(1); + + RaceTestUtils.race( + () -> subscriber.onSubscribe(innerSubscription), () -> downstreamSubscription.request(1)); + + Assertions.assertThat(requested.get()).isEqualTo(3); + } + } + + @Test + public void racingConditionalTest() { + for (int i = 0; i < 1000; i++) { + CoreSubscriber[] subscribers = new CoreSubscriber[1]; + Subscription[] downstreamSubscriptions = new Subscription[1]; + Subscription[] innerSubscriptions = new Subscription[1]; + + AtomicLong requested = new AtomicLong(); + + Flux.just(2) + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new FluxSwitchOnFirst<>( + flux, + (s, f) -> + new Flux() { + + @Override + public void subscribe(CoreSubscriber actual) { + subscribers[0] = actual; + f.subscribe( + new Fuseable.ConditionalSubscriber() { + @Override + public boolean tryOnNext(Integer integer) { + return ((Fuseable.ConditionalSubscriber) actual) + .tryOnNext(integer); + } + + @Override + public void onSubscribe(Subscription s) { + innerSubscriptions[0] = s; + } + + @Override + public void onNext(Integer integer) { + actual.onNext(integer); + } + + @Override + public void onError(Throwable throwable) { + actual.onError(throwable); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + }); + } + }, + false)) + .filter(__ -> true) + .subscribe(__ -> {}, __ -> {}, () -> {}, s -> downstreamSubscriptions[0] = s); + + CoreSubscriber subscriber = subscribers[0]; + Subscription downstreamSubscription = downstreamSubscriptions[0]; + Subscription innerSubscription = innerSubscriptions[0]; + innerSubscription.request(1); + + RaceTestUtils.race( + () -> subscriber.onSubscribe(innerSubscription), () -> downstreamSubscription.request(1)); + + Assertions.assertThat(requested.get()).isEqualTo(3); + } + } + + private static final class NoOpsScheduler implements Scheduler { + + static final NoOpsScheduler INSTANCE = new NoOpsScheduler(); + + private NoOpsScheduler() {} + + @Override + public Disposable schedule(Runnable task) { + return Disposables.composite(); + } + + @Override + public Worker createWorker() { + return NoOpsWorker.INSTANCE; + } + + static final class NoOpsWorker implements Worker { + + static final NoOpsWorker INSTANCE = new NoOpsWorker(); + + @Override + public Disposable schedule(Runnable task) { + return Disposables.never(); + } + + @Override + public void dispose() {} + }; + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index 1c00b0502..56ea60feb 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -219,7 +219,8 @@ default void requestChannel2_000_000() { default void requestChannel3() { AtomicLong requested = new AtomicLong(); Flux payloads = - Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); + Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); + getClient() .requestChannel(payloads) From 781819ecf69900f2cd7d81d0e3e3c89e2095e34d Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Fri, 3 Apr 2020 22:07:48 +0300 Subject: [PATCH 03/11] more fixes Signed-off-by: Oleh Dokuka --- .../rsocket/internal/FluxSwitchOnFirst.java | 657 --------- .../RateLimitableRequestSubscriber.java | 239 ---- .../internal/FluxSwitchOnFirstTest.java | 1274 ----------------- .../tcp/resume/ResumeFileTransfer.java | 6 +- 4 files changed, 5 insertions(+), 2171 deletions(-) delete mode 100644 rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java delete mode 100755 rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java delete mode 100644 rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java diff --git a/rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java b/rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java deleted file mode 100644 index 9a6ef4214..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/FluxSwitchOnFirst.java +++ /dev/null @@ -1,657 +0,0 @@ -/* - * Copyright (c) 2011-2018 Pivotal Software Inc, All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.BiFunction; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Fuseable; -import reactor.core.Scannable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxOperator; -import reactor.core.publisher.Operators; -import reactor.core.publisher.Signal; -import reactor.util.annotation.Nullable; -import reactor.util.context.Context; - -/** - * @author Oleh Dokuka - * @param - * @param - */ -public final class FluxSwitchOnFirst extends FluxOperator { - static final int STATE_CANCELLED = -2; - static final int STATE_REQUESTED = -1; - - static final int STATE_INIT = 0; - static final int STATE_SUBSCRIBED_ONCE = 1; - - final BiFunction, Flux, Publisher> transformer; - final boolean cancelSourceOnComplete; - - volatile int once; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(FluxSwitchOnFirst.class, "once"); - - public FluxSwitchOnFirst( - Flux source, - BiFunction, Flux, Publisher> transformer, - boolean cancelSourceOnComplete) { - super(source); - this.transformer = Objects.requireNonNull(transformer, "transformer"); - this.cancelSourceOnComplete = cancelSourceOnComplete; - } - - @Override - public int getPrefetch() { - return 1; - } - - @Override - @SuppressWarnings("unchecked") - public void subscribe(CoreSubscriber actual) { - Objects.requireNonNull(actual, "subscribe"); - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - if (actual instanceof Fuseable.ConditionalSubscriber) { - source.subscribe( - new SwitchOnFirstConditionalMain<>( - (Fuseable.ConditionalSubscriber) actual, - transformer, - cancelSourceOnComplete)); - return; - } - source.subscribe(new SwitchOnFirstMain<>(actual, transformer, cancelSourceOnComplete)); - } else { - Operators.error(actual, new IllegalStateException("Allows only a single Subscriber")); - } - } - - abstract static class AbstractSwitchOnFirstMain extends Flux - implements CoreSubscriber, Scannable, Subscription { - - final ControlSubscriber outer; - final BiFunction, Flux, Publisher> transformer; - - Subscription s; - Throwable throwable; - boolean first = true; - boolean done; - - volatile CoreSubscriber inner; - - @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater INNER = - AtomicReferenceFieldUpdater.newUpdater( - AbstractSwitchOnFirstMain.class, CoreSubscriber.class, "inner"); - - volatile int state; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater STATE = - AtomicIntegerFieldUpdater.newUpdater(AbstractSwitchOnFirstMain.class, "state"); - - @SuppressWarnings("unchecked") - AbstractSwitchOnFirstMain( - CoreSubscriber outer, - BiFunction, Flux, Publisher> transformer, - boolean cancelSourceOnComplete) { - this.outer = - outer instanceof Fuseable.ConditionalSubscriber - ? new SwitchOnFirstConditionalControlSubscriber<>( - this, (Fuseable.ConditionalSubscriber) outer, cancelSourceOnComplete) - : new SwitchOnFirstControlSubscriber<>(this, outer, cancelSourceOnComplete); - this.transformer = transformer; - } - - @Override - @Nullable - public Object scanUnsafe(Attr key) { - CoreSubscriber i = this.inner; - - if (key == Attr.CANCELLED) return i == CancelledSubscriber.INSTANCE; - if (key == Attr.TERMINATED) return done || i == CancelledSubscriber.INSTANCE; - - return null; - } - - @Override - public Context currentContext() { - CoreSubscriber actual = inner; - - if (actual != null) { - return actual.currentContext(); - } - - return outer.currentContext(); - } - - @Override - public void cancel() { - if (INNER.getAndSet(this, CancelledSubscriber.INSTANCE) != CancelledSubscriber.INSTANCE) { - s.cancel(); - } - } - - @Override - public void onSubscribe(Subscription s) { - if (Operators.validate(this.s, s)) { - this.s = s; - outer.sendSubscription(); - s.request(1); - } - } - - @Override - public void onNext(T t) { - CoreSubscriber i = inner; - if (i == CancelledSubscriber.INSTANCE || done) { - Operators.onNextDropped(t, currentContext()); - return; - } - - if (i == null) { - Publisher result; - CoreSubscriber o = outer; - - try { - result = - Objects.requireNonNull( - transformer.apply(Signal.next(t, o.currentContext()), this), - "The transformer returned a null value"); - } catch (Throwable e) { - done = true; - Operators.error(o, Operators.onOperatorError(s, e, t, o.currentContext())); - return; - } - - first = false; - result.subscribe(o); - return; - } - - i.onNext(t); - } - - @Override - public void onError(Throwable t) { - CoreSubscriber i = inner; - if (i == CancelledSubscriber.INSTANCE || done) { - Operators.onErrorDropped(t, currentContext()); - return; - } - - throwable = t; - done = true; - - if (first && i == null) { - Publisher result; - CoreSubscriber o = outer; - - try { - result = - Objects.requireNonNull( - transformer.apply(Signal.error(t, o.currentContext()), this), - "The transformer returned a null value"); - } catch (Throwable e) { - done = true; - Operators.error(o, Operators.onOperatorError(s, e, t, o.currentContext())); - return; - } - - first = false; - result.subscribe(o); - } - - i = this.inner; - if (i != null) { - i.onError(t); - } - } - - @Override - public void onComplete() { - CoreSubscriber i = inner; - if (i == CancelledSubscriber.INSTANCE || done) { - return; - } - - done = true; - - if (first && i == null) { - Publisher result; - CoreSubscriber o = outer; - - try { - result = - Objects.requireNonNull( - transformer.apply(Signal.complete(o.currentContext()), this), - "The transformer returned a null value"); - } catch (Throwable e) { - done = true; - Operators.error(o, Operators.onOperatorError(s, e, null, o.currentContext())); - return; - } - - first = false; - result.subscribe(o); - } - - i = inner; - if (i != null) { - i.onComplete(); - } - } - - @Override - public void request(long n) { - if (Operators.validate(n)) { - s.request(n); - } - } - } - - static final class SwitchOnFirstMain extends AbstractSwitchOnFirstMain { - - SwitchOnFirstMain( - CoreSubscriber outer, - BiFunction, Flux, Publisher> transformer, - boolean cancelSourceOnComplete) { - super(outer, transformer, cancelSourceOnComplete); - } - - @Override - public void subscribe(CoreSubscriber actual) { - if (state == STATE_INIT && STATE.compareAndSet(this, STATE_INIT, STATE_SUBSCRIBED_ONCE)) { - if (done) { - if (throwable != null) { - Operators.error(actual, throwable); - } else { - Operators.complete(actual); - } - return; - } - INNER.lazySet(this, actual); - actual.onSubscribe(this); - } else { - Operators.error( - actual, new IllegalStateException("FluxSwitchOnFirst allows only one Subscriber")); - } - } - } - - static final class SwitchOnFirstConditionalMain extends AbstractSwitchOnFirstMain - implements Fuseable.ConditionalSubscriber { - - SwitchOnFirstConditionalMain( - Fuseable.ConditionalSubscriber outer, - BiFunction, Flux, Publisher> transformer, - boolean cancelSourceOnComplete) { - super(outer, transformer, cancelSourceOnComplete); - } - - @Override - public void subscribe(CoreSubscriber actual) { - if (state == STATE_INIT && STATE.compareAndSet(this, STATE_INIT, STATE_SUBSCRIBED_ONCE)) { - if (done) { - if (throwable != null) { - Operators.error(actual, throwable); - } else { - Operators.complete(actual); - } - return; - } - INNER.lazySet(this, Operators.toConditionalSubscriber(actual)); - actual.onSubscribe(this); - } else { - Operators.error( - actual, new IllegalStateException("FluxSwitchOnFirst allows only one Subscriber")); - } - } - - @Override - @SuppressWarnings("unchecked") - public boolean tryOnNext(T t) { - CoreSubscriber i = inner; - if (i == CancelledSubscriber.INSTANCE || done) { - Operators.onNextDropped(t, currentContext()); - return false; - } - - if (i == null) { - Publisher result; - CoreSubscriber o = outer; - - try { - result = - Objects.requireNonNull( - transformer.apply(Signal.next(t, o.currentContext()), this), - "The transformer returned a null value"); - } catch (Throwable e) { - done = true; - Operators.error(o, Operators.onOperatorError(s, e, t, o.currentContext())); - return false; - } - - first = false; - result.subscribe(o); - return true; - } - - return ((Fuseable.ConditionalSubscriber) i).tryOnNext(t); - } - } - - static final class SwitchOnFirstControlSubscriber - implements ControlSubscriber, Scannable, Subscription { - - final AbstractSwitchOnFirstMain parent; - final CoreSubscriber actual; - final boolean cancelSourceOnComplete; - - volatile long requested; - - @SuppressWarnings("rawtypes") - static final AtomicLongFieldUpdater REQUESTED = - AtomicLongFieldUpdater.newUpdater(SwitchOnFirstControlSubscriber.class, "requested"); - - Subscription s; - - SwitchOnFirstControlSubscriber( - AbstractSwitchOnFirstMain parent, - CoreSubscriber actual, - boolean cancelSourceOnComplete) { - this.parent = parent; - this.actual = actual; - this.cancelSourceOnComplete = cancelSourceOnComplete; - } - - @Override - public Context currentContext() { - return actual.currentContext(); - } - - @Override - public void sendSubscription() { - actual.onSubscribe(this); - } - - @Override - public void onSubscribe(Subscription s) { - if (this.s == null && this.requested != STATE_CANCELLED) { - this.s = s; - - tryRequest(); - } else { - s.cancel(); - } - } - - @Override - public void onNext(T t) { - actual.onNext(t); - } - - @Override - public void onError(Throwable throwable) { - if (!parent.done) { - parent.cancel(); - } - - actual.onError(throwable); - } - - @Override - public void onComplete() { - if (!parent.done && cancelSourceOnComplete) { - parent.cancel(); - } - - actual.onComplete(); - } - - @Override - public void request(long n) { - long r = this.requested; - - if (r > STATE_REQUESTED) { - long u; - for (; ; ) { - if (r == Long.MAX_VALUE) { - return; - } - u = Operators.addCap(r, n); - if (REQUESTED.compareAndSet(this, r, u)) { - return; - } else { - r = requested; - - if (r < 0) { - break; - } - } - } - } - - if (r == STATE_CANCELLED) { - return; - } - - s.request(n); - } - - void tryRequest() { - final Subscription s = this.s; - long r = REQUESTED.getAndSet(this, -1); - - if (r > 0) { - s.request(r); - } - } - - @Override - public void cancel() { - final long state = REQUESTED.getAndSet(this, STATE_CANCELLED); - - if (state == STATE_CANCELLED) { - return; - } - - if (state == STATE_REQUESTED) { - s.cancel(); - } - - parent.cancel(); - } - - @Override - public Object scanUnsafe(Attr key) { - if (key == Attr.PARENT) return parent; - if (key == Attr.ACTUAL) return actual; - - return null; - } - } - - static final class SwitchOnFirstConditionalControlSubscriber - implements ControlSubscriber, Scannable, Subscription, Fuseable.ConditionalSubscriber { - - final AbstractSwitchOnFirstMain parent; - final Fuseable.ConditionalSubscriber actual; - final boolean terminateUpstreamOnComplete; - - volatile long requested; - - @SuppressWarnings("rawtypes") - static final AtomicLongFieldUpdater REQUESTED = - AtomicLongFieldUpdater.newUpdater( - SwitchOnFirstConditionalControlSubscriber.class, "requested"); - - Subscription s; - - SwitchOnFirstConditionalControlSubscriber( - AbstractSwitchOnFirstMain parent, - Fuseable.ConditionalSubscriber actual, - boolean terminateUpstreamOnComplete) { - this.parent = parent; - this.actual = actual; - this.terminateUpstreamOnComplete = terminateUpstreamOnComplete; - } - - @Override - public void sendSubscription() { - actual.onSubscribe(this); - } - - @Override - public Context currentContext() { - return actual.currentContext(); - } - - @Override - public void onSubscribe(Subscription s) { - if (this.s == null && this.requested != STATE_CANCELLED) { - this.s = s; - - tryRequest(); - } else { - s.cancel(); - } - } - - @Override - public void onNext(T t) { - actual.onNext(t); - } - - @Override - public boolean tryOnNext(T t) { - return actual.tryOnNext(t); - } - - @Override - public void onError(Throwable throwable) { - if (!parent.done) { - parent.cancel(); - } - - actual.onError(throwable); - } - - @Override - public void onComplete() { - if (!parent.done && terminateUpstreamOnComplete) { - parent.cancel(); - } - - actual.onComplete(); - } - - @Override - public void request(long n) { - long r = this.requested; - - if (r > STATE_REQUESTED) { - long u; - for (; ; ) { - if (r == Long.MAX_VALUE) { - return; - } - u = Operators.addCap(r, n); - if (REQUESTED.compareAndSet(this, r, u)) { - return; - } else { - r = requested; - - if (r < 0) { - break; - } - } - } - } - - if (r == STATE_CANCELLED) { - return; - } - - s.request(n); - } - - void tryRequest() { - final Subscription s = this.s; - long r = REQUESTED.getAndSet(this, -1); - - if (r > 0) { - s.request(r); - } - } - - @Override - public void cancel() { - final long state = REQUESTED.getAndSet(this, STATE_CANCELLED); - - if (state == STATE_CANCELLED) { - return; - } - - if (state == STATE_REQUESTED) { - s.cancel(); - return; - } - - parent.cancel(); - } - - @Override - public Object scanUnsafe(Attr key) { - if (key == Attr.PARENT) return parent; - if (key == Attr.ACTUAL) return actual; - - return null; - } - } - - interface ControlSubscriber extends CoreSubscriber { - - void sendSubscription(); - } - - static final class CancelledSubscriber implements CoreSubscriber { - - static final CancelledSubscriber INSTANCE = new CancelledSubscriber(); - - private CancelledSubscriber() {} - - @Override - public void onSubscribe(Subscription s) {} - - @Override - public void onNext(Object o) {} - - @Override - public void onError(Throwable t) {} - - @Override - public void onComplete() {} - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java b/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java deleted file mode 100755 index 2fb18e1f7..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestSubscriber.java +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Exceptions; -import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; - -/** */ -public abstract class RateLimitableRequestSubscriber implements CoreSubscriber, Subscription { - - private final long prefetch; - private final long limit; - - private long externalRequested; // need sync - private int pendingToFulfil; // need sync since should be checked/zerroed in onNext - // and increased in request - private int deliveredElements; // no need to sync since increased zerroed only in - // the request method - - private volatile Subscription subscription; - static final AtomicReferenceFieldUpdater S = - AtomicReferenceFieldUpdater.newUpdater( - RateLimitableRequestSubscriber.class, Subscription.class, "subscription"); - - public RateLimitableRequestSubscriber(long prefetch) { - this.prefetch = prefetch; - this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : (prefetch - (prefetch >> 2)); - } - - protected void hookOnSubscribe(Subscription s) { - // NO-OP - } - - protected void hookOnNext(T value) { - // NO-OP - } - - protected void hookOnComplete() { - // NO-OP - } - - protected void hookOnError(Throwable throwable) { - throw Exceptions.errorCallbackNotImplemented(throwable); - } - - protected void hookOnCancel() { - // NO-OP - } - - protected void hookFinally(SignalType type) { - // NO-OP - } - - void safeHookFinally(SignalType type) { - try { - hookFinally(type); - } catch (Throwable finallyFailure) { - Operators.onErrorDropped(finallyFailure, currentContext()); - } - } - - @Override - public final void onSubscribe(Subscription s) { - if (Operators.validate(subscription, s)) { - this.subscription = s; - try { - hookOnSubscribe(s); - requestN(); - } catch (Throwable throwable) { - onError(Operators.onOperatorError(s, throwable, currentContext())); - } - } - } - - @Override - public final void onNext(T t) { - try { - hookOnNext(t); - - if (prefetch == Integer.MAX_VALUE) { - return; - } - - final long l = limit; - int d = deliveredElements + 1; - - if (d == l) { - d = 0; - final long r; - final Subscription s = subscription; - - if (s == null) { - return; - } - - synchronized (this) { - long er = externalRequested; - - if (er >= l) { - er -= l; - // keep pendingToFulfil as is since it is eq to prefetch - r = l; - } else { - pendingToFulfil -= l; - if (er > 0) { - r = er; - er = 0; - pendingToFulfil += r; - } else { - r = 0; - } - } - - externalRequested = er; - } - - if (r > 0) { - s.request(r); - } - } - - deliveredElements = d; - } catch (Throwable e) { - onError(e); - } - } - - @Override - public final void onError(Throwable t) { - Subscription s = S.getAndSet(this, Operators.cancelledSubscription()); - if (s == Operators.cancelledSubscription()) { - Operators.onErrorDropped(t, this.currentContext()); - return; - } - - try { - hookOnError(t); - } catch (Throwable e) { - e = Exceptions.addSuppressed(e, t); - Operators.onErrorDropped(e, currentContext()); - } finally { - safeHookFinally(SignalType.ON_ERROR); - } - } - - @Override - public final void onComplete() { - if (S.getAndSet(this, Operators.cancelledSubscription()) != Operators.cancelledSubscription()) { - // we're sure it has not been concurrently cancelled - try { - hookOnComplete(); - } catch (Throwable throwable) { - // onError itself will short-circuit due to the CancelledSubscription being set above - hookOnError(Operators.onOperatorError(throwable, currentContext())); - } finally { - safeHookFinally(SignalType.ON_COMPLETE); - } - } - } - - @Override - public final void request(long n) { - synchronized (this) { - long requested = externalRequested; - if (requested == Long.MAX_VALUE) { - return; - } - externalRequested = Operators.addCap(n, requested); - } - - requestN(); - } - - private void requestN() { - final long r; - final Subscription s = subscription; - - if (s == null) { - return; - } - - synchronized (this) { - final long er = externalRequested; - final long p = prefetch; - final int pendingFulfil = pendingToFulfil; - - if (er != Long.MAX_VALUE || p != Integer.MAX_VALUE) { - // shortcut - if (pendingFulfil == p) { - return; - } - - r = Math.min(p - pendingFulfil, er); - if (er != Long.MAX_VALUE) { - externalRequested -= r; - } - if (p != Integer.MAX_VALUE) { - pendingToFulfil += r; - } - } else { - r = Long.MAX_VALUE; - } - } - - if (r > 0) { - s.request(r); - } - } - - public final void cancel() { - if (Operators.terminate(S, this)) { - try { - hookOnCancel(); - } catch (Throwable throwable) { - hookOnError(Operators.onOperatorError(subscription, throwable, currentContext())); - } finally { - safeHookFinally(SignalType.CANCEL); - } - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java b/rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java deleted file mode 100644 index 7628a5304..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/FluxSwitchOnFirstTest.java +++ /dev/null @@ -1,1274 +0,0 @@ -/* - * Copyright (c) 2011-2018 Pivotal Software Inc, All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import org.assertj.core.api.Assertions; -import org.junit.Test; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Disposable; -import reactor.core.Disposables; -import reactor.core.Fuseable; -import reactor.core.publisher.EmitterProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.ReplayProcessor; -import reactor.core.publisher.Signal; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; -import reactor.test.publisher.TestPublisher; -import reactor.test.util.RaceTestUtils; -import reactor.util.context.Context; - -public class FluxSwitchOnFirstTest { - - @Test - public void shouldNotSubscribeTwice() { - Throwable[] throwables = new Throwable[1]; - CountDownLatch latch = new CountDownLatch(1); - StepVerifier.create( - Flux.just(1L) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - RaceTestUtils.race( - () -> - f.subscribe( - __ -> {}, - t -> { - throwables[0] = t; - latch.countDown(); - }, - latch::countDown), - () -> - f.subscribe( - __ -> {}, - t -> { - throwables[0] = t; - latch.countDown(); - }, - latch::countDown)); - - return Flux.empty(); - }, - false))) - .expectSubscription() - .expectComplete() - .verify(); - - Assertions.assertThat(throwables[0]) - .isInstanceOf(IllegalStateException.class) - .hasMessage("FluxSwitchOnFirst allows only one Subscriber"); - } - - @Test - public void shouldNotSubscribeTwiceConditional() { - Throwable[] throwables = new Throwable[1]; - CountDownLatch latch = new CountDownLatch(1); - StepVerifier.create( - Flux.just(1L) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - RaceTestUtils.race( - () -> - f.subscribe( - __ -> {}, - t -> { - throwables[0] = t; - latch.countDown(); - }, - latch::countDown), - () -> - f.subscribe( - __ -> {}, - t -> { - throwables[0] = t; - latch.countDown(); - }, - latch::countDown)); - - return Flux.empty(); - }, - false) - .filter(e -> true))) - .expectSubscription() - .expectComplete() - .verify(); - - Assertions.assertThat(throwables[0]) - .isInstanceOf(IllegalStateException.class) - .hasMessage("FluxSwitchOnFirst allows only one Subscriber"); - } - - @Test - public void shouldNotSubscribeTwiceWhenCanceled() { - CountDownLatch latch = new CountDownLatch(1); - StepVerifier.create( - Flux.just(1L, 2L) - .doOnComplete( - () -> { - try { - latch.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - }) - .hide() - .publishOn(Schedulers.parallel()) - .cancelOn(NoOpsScheduler.INSTANCE) - .doOnCancel(latch::countDown) - .transform(flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> f, false)) - .doOnSubscribe( - s -> Schedulers.single().schedule(s::cancel, 10, TimeUnit.MILLISECONDS))) - .expectSubscription() - .expectNext(2L) - .expectNoEvent(Duration.ofMillis(200)) - .thenCancel() - .verifyThenAssertThat() - .hasNotDroppedErrors(); - } - - @Test - public void shouldNotSubscribeTwiceConditionalWhenCanceled() { - CountDownLatch latch = new CountDownLatch(1); - StepVerifier.create( - Flux.just(1L, 2L) - .doOnComplete( - () -> { - try { - latch.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - }) - .hide() - .publishOn(Schedulers.parallel()) - .cancelOn(NoOpsScheduler.INSTANCE) - .doOnCancel(latch::countDown) - .transform(flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> f, false)) - .filter(e -> true) - .doOnSubscribe( - s -> Schedulers.single().schedule(s::cancel, 10, TimeUnit.MILLISECONDS))) - .expectSubscription() - .expectNext(2L) - .expectNoEvent(Duration.ofMillis(200)) - .thenCancel() - .verifyThenAssertThat() - .hasNotDroppedErrors(); - } - - @Test - public void shouldSendOnErrorSignalConditional() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - RuntimeException error = new RuntimeException(); - StepVerifier.create( - Flux.error(error) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f; - }, - false)) - .filter(e -> true)) - .expectSubscription() - .expectError(RuntimeException.class) - .verify(); - - Assertions.assertThat(first).containsExactly(Signal.error(error)); - } - - @Test - public void shouldSendOnNextSignalConditional() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - StepVerifier.create( - Flux.just(1L) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f; - }, - false)) - .filter(e -> true)) - .expectSubscription() - .expectComplete() - .verify(); - - Assertions.assertThat((long) first[0].get()).isEqualTo(1L); - } - - @Test - public void shouldSendOnErrorSignalWithDelaySubscription() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - RuntimeException error = new RuntimeException(); - StepVerifier.create( - Flux.error(error) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f.delaySubscription(Duration.ofMillis(100)); - }, - false))) - .expectSubscription() - .expectError(RuntimeException.class) - .verify(); - - Assertions.assertThat(first).containsExactly(Signal.error(error)); - } - - @Test - public void shouldSendOnCompleteSignalWithDelaySubscription() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - StepVerifier.create( - Flux.empty() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f.delaySubscription(Duration.ofMillis(100)); - }, - false))) - .expectSubscription() - .expectComplete() - .verify(); - - Assertions.assertThat(first).containsExactly(Signal.complete()); - } - - @Test - public void shouldSendOnErrorSignal() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - RuntimeException error = new RuntimeException(); - StepVerifier.create( - Flux.error(error) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f; - }, - false))) - .expectSubscription() - .expectError(RuntimeException.class) - .verify(); - - Assertions.assertThat(first).containsExactly(Signal.error(error)); - } - - @Test - public void shouldSendOnNextSignal() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - StepVerifier.create( - Flux.just(1L) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f; - }, - false))) - .expectSubscription() - .expectComplete() - .verify(); - - Assertions.assertThat((long) first[0].get()).isEqualTo(1L); - } - - @Test - public void shouldSendOnNextAsyncSignal() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - StepVerifier.create( - Flux.just(1L) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f.subscribeOn(Schedulers.elastic()); - }, - false))) - .expectSubscription() - .expectComplete() - .verify(); - - Assertions.assertThat((long) first[0].get()).isEqualTo(1L); - } - - @Test - public void shouldSendOnNextAsyncSignalConditional() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - StepVerifier.create( - Flux.just(1L) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - - return f.subscribeOn(Schedulers.elastic()); - }, - false)) - .filter(p -> true)) - .expectSubscription() - .expectComplete() - .verify(); - - Assertions.assertThat((long) first[0].get()).isEqualTo(1L); - } - - @Test - public void shouldNeverSendIncorrectRequestSizeToUpstream() throws InterruptedException { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong capture = new AtomicLong(-1); - ArrayList requested = new ArrayList<>(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::add) - .transform( - flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)); - - publisher.next(1L); - publisher.complete(); - - switchTransformed.subscribe(capture::set, __ -> {}, latch::countDown, s -> s.request(1)); - - latch.await(5, TimeUnit.SECONDS); - - Assertions.assertThat(capture.get()).isEqualTo(-1); - Assertions.assertThat(requested).containsExactly(1L, 1L); - } - - @Test - public void shouldNeverSendIncorrectRequestSizeToUpstreamConditional() - throws InterruptedException { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong capture = new AtomicLong(-1); - ArrayList requested = new ArrayList<>(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(e1 -> requested.add(e1)) - .transform( - flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)) - .filter(e -> true); - - publisher.next(1L); - publisher.complete(); - - switchTransformed.subscribe(capture::set, __ -> {}, latch::countDown, s -> s.request(1)); - - latch.await(5, TimeUnit.SECONDS); - - Assertions.assertThat(capture.get()).isEqualTo(-1L); - Assertions.assertThat(requested).containsExactly(1L, 1L); - } - - @Test - public void shouldBeRequestedOneFromUpstreamTwiceInCaseOfConditional() - throws InterruptedException { - TestPublisher publisher = TestPublisher.createCold(); - ArrayList capture = new ArrayList<>(); - ArrayList requested = new ArrayList<>(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::add) - .transform( - flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)) - .filter(e -> false); - - publisher.next(1L); - publisher.complete(); - - switchTransformed.subscribe(capture::add, __ -> {}, latch::countDown, s -> s.request(1)); - - latch.await(5, TimeUnit.SECONDS); - - Assertions.assertThat(capture).isEmpty(); - Assertions.assertThat(requested).containsExactly(1L, 1L); - } - - @Test - public void shouldBeRequestedExactlyOneAndThenLongMaxValue() throws InterruptedException { - TestPublisher publisher = TestPublisher.createCold(); - ArrayList capture = new ArrayList<>(); - ArrayList requested = new ArrayList<>(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::add) - .transform( - flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)); - - publisher.next(1L); - publisher.complete(); - - switchTransformed.subscribe(capture::add, __ -> {}, latch::countDown); - - latch.await(5, TimeUnit.SECONDS); - - Assertions.assertThat(capture).isEmpty(); - Assertions.assertThat(requested).containsExactly(1L, Long.MAX_VALUE); - } - - @Test - public void shouldBeRequestedExactlyOneAndThenLongMaxValueConditional() - throws InterruptedException { - TestPublisher publisher = TestPublisher.createCold(); - ArrayList capture = new ArrayList<>(); - ArrayList requested = new ArrayList<>(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::add) - .transform( - flux -> new FluxSwitchOnFirst<>(flux, (first, innerFlux) -> innerFlux, false)); - - publisher.next(1L); - publisher.complete(); - - switchTransformed.subscribe(capture::add, __ -> {}, latch::countDown); - - latch.await(5, TimeUnit.SECONDS); - - Assertions.assertThat(capture).isEmpty(); - Assertions.assertThat(requested).containsExactly(1L, Long.MAX_VALUE); - } - - @Test - public void shouldReturnCorrectContextOnEmptySource() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - Flux switchTransformed = - Flux.empty() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (f, innerFlux) -> { - first[0] = f; - return innerFlux; - }, - false)) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - StepVerifier.create(switchTransformed, 0) - .expectSubscription() - .thenRequest(1) - .expectAccessibleContext() - .contains("a", "c") - .contains("c", "d") - .then() - .expectComplete() - .verify(); - - Assertions.assertThat(first) - .containsExactly(Signal.complete(Context.of("a", "c").put("c", "d"))); - } - - @Test - public void shouldNotFailOnIncorrectPublisherBehavior() { - TestPublisher publisher = - TestPublisher.createNoncompliant(TestPublisher.Violation.CLEANUP_ON_TERMINATE); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (first, innerFlux) -> innerFlux.subscriberContext(Context.of("a", "b")), - false)); - - StepVerifier.create( - new Flux() { - @Override - public void subscribe(CoreSubscriber actual) { - switchTransformed.subscribe(actual); - publisher.next(1L); - } - }, - 0) - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext(2L) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .expectError() - .verifyThenAssertThat() - .hasDroppedErrors(3) - .tookLessThan(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void shouldBeAbleToAccessUpstreamContext() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (first, innerFlux) -> - innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")), - false)) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - publisher.next(1L); - publisher.next(2L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("2") - .thenRequest(1) - .then(() -> publisher.next(3L)) - .expectNext("3") - .expectAccessibleContext() - .contains("a", "b") - .contains("c", "d") - .then() - .then(publisher::complete) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void shouldNotHangWhenOneElementUpstream() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (first, innerFlux) -> - innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")), - false)) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - publisher.next(1L); - publisher.complete(); - - StepVerifier.create(switchTransformed, 0).expectComplete().verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void backpressureTest() { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); - - publisher.next(1L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext("2") - .then(publisher::complete) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - - Assertions.assertThat(requested.get()).isEqualTo(2L); - } - - @Test - public void backpressureConditionalTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)) - .filter(e -> false); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assertions.assertThat(requested.get()).isEqualTo(2L); - } - - @Test - public void backpressureHiddenConditionalTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf).hide(), false)) - .filter(e -> false); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assertions.assertThat(requested.get()).isEqualTo(10001L); - } - - @Test - public void backpressureDrawbackOnConditionalInTransformTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (first, innerFlux) -> innerFlux.map(String::valueOf).filter(e -> false), - false)); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assertions.assertThat(requested.get()).isEqualTo(10001L); - } - - @Test - public void shouldErrorOnOverflowTest() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); - - publisher.next(1L); - publisher.next(2L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("2") - .then(() -> publisher.next(2L)) - .expectErrorSatisfies( - t -> - Assertions.assertThat(t) - .isInstanceOf(IllegalStateException.class) - .hasMessage("Can't deliver value due to lack of requests")) - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void shouldPropagateonCompleteCorrectly() { - Flux switchTransformed = - Flux.empty() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); - - StepVerifier.create(switchTransformed).expectComplete().verify(Duration.ofSeconds(10)); - } - - @Test - public void shouldPropagateOnCompleteWithMergedElementsCorrectly() { - Flux switchTransformed = - Flux.empty() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (first, innerFlux) -> - innerFlux.map(String::valueOf).mergeWith(Flux.just("1", "2", "3")), - false)); - - StepVerifier.create(switchTransformed) - .expectNext("1", "2", "3") - .expectComplete() - .verify(Duration.ofSeconds(10)); - } - - @Test - public void shouldPropagateErrorCorrectly() { - Flux switchTransformed = - Flux.error(new RuntimeException("hello")) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf), true)); - - StepVerifier.create(switchTransformed) - .expectErrorMessage("hello") - .verify(Duration.ofSeconds(10)); - } - - @Test - public void shouldBeAbleToBeCancelledProperly() { - TestPublisher publisher = TestPublisher.createCold(); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)); - - publisher.next(1); - - StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - } - - @Test - public void shouldBeAbleToBeCancelledProperly2() { - TestPublisher publisher = TestPublisher.createCold(); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf).take(1), false)); - - publisher.next(1); - publisher.next(2); - publisher.next(3); - publisher.next(4); - - StepVerifier.create(switchTransformed, 1) - .expectNext("2") - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - } - - @Test - public void shouldBeAbleToBeCancelledProperly3() { - TestPublisher publisher = TestPublisher.createCold(); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf), false)) - .take(1); - - publisher.next(1); - publisher.next(2); - publisher.next(3); - publisher.next(4); - - StepVerifier.create(switchTransformed, 1) - .expectNext("2") - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - } - - @Test - public void shouldReturnNormallyIfExceptionIsThrownOnNextDuringSwitching() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - Optional expectedCause = Optional.of(1L); - - StepVerifier.create( - Flux.just(1L) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - throw new NullPointerException(); - }, - false))) - .expectSubscription() - .expectError(NullPointerException.class) - .verifyThenAssertThat() - .hasOperatorErrorsSatisfying( - c -> - Assertions.assertThat(c) - .hasOnlyOneElementSatisfying( - t -> { - Assertions.assertThat(t.getT1()) - .containsInstanceOf(NullPointerException.class); - Assertions.assertThat(t.getT2()).isEqualTo(expectedCause); - })); - - Assertions.assertThat((long) first[0].get()).isEqualTo(1L); - } - - @Test - public void shouldReturnNormallyIfExceptionIsThrownOnErrorDuringSwitching() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - NullPointerException npe = new NullPointerException(); - RuntimeException error = new RuntimeException(); - StepVerifier.create( - Flux.error(error) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - throw npe; - }, - false))) - .expectSubscription() - .verifyError(NullPointerException.class); - - Assertions.assertThat(first).containsExactly(Signal.error(error)); - } - - @Test - public void shouldReturnNormallyIfExceptionIsThrownOnCompleteDuringSwitching() { - @SuppressWarnings("unchecked") - Signal[] first = new Signal[1]; - - StepVerifier.create( - Flux.empty() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - first[0] = s; - throw new NullPointerException(); - }, - false))) - .expectSubscription() - .expectError(NullPointerException.class) - .verifyThenAssertThat() - .hasOperatorErrorMatching( - t -> { - Assertions.assertThat(t).isInstanceOf(NullPointerException.class); - return true; - }); - - Assertions.assertThat(first).containsExactly(Signal.complete()); - } - - @Test - public void sourceSubscribedOnce() { - AtomicInteger subCount = new AtomicInteger(); - Flux source = - Flux.range(1, 10).hide().doOnSubscribe(subscription -> subCount.incrementAndGet()); - - StepVerifier.create( - source.transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (s, f) -> f.filter(v -> v % 2 == s.get()), false))) - .expectNext(3, 5, 7, 9) - .verifyComplete(); - - Assertions.assertThat(subCount).hasValue(1); - } - - @Test - public void checkHotSource() { - ReplayProcessor processor = ReplayProcessor.create(1); - - processor.onNext(1L); - processor.onNext(2L); - processor.onNext(3L); - - StepVerifier.create( - processor.transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (s, f) -> f.filter(v -> v % s.get() == 0), false))) - .then( - () -> { - processor.onNext(4L); - processor.onNext(5L); - processor.onNext(6L); - processor.onNext(7L); - processor.onNext(8L); - processor.onNext(9L); - processor.onComplete(); - }) - .expectNext(6L, 9L) - .verifyComplete(); - } - - @Test - public void shouldCancelSourceOnUnrelatedPublisherComplete() { - EmitterProcessor testPublisher = EmitterProcessor.create(); - - testPublisher.onNext(1L); - - StepVerifier.create( - testPublisher.transform( - flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> Flux.empty(), true))) - .expectSubscription() - .verifyComplete(); - - Assertions.assertThat(testPublisher.isCancelled()).isTrue(); - } - - @Test - public void shouldCancelSourceOnUnrelatedPublisherError() { - EmitterProcessor testPublisher = EmitterProcessor.create(); - - testPublisher.onNext(1L); - - StepVerifier.create( - testPublisher.transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (s, f) -> Flux.error(new RuntimeException("test")), false))) - .expectSubscription() - .verifyErrorSatisfies( - t -> - Assertions.assertThat(t) - .hasMessage("test") - .isExactlyInstanceOf(RuntimeException.class)); - - Assertions.assertThat(testPublisher.isCancelled()).isTrue(); - } - - @Test - public void shouldCancelSourceOnUnrelatedPublisherCancel() { - TestPublisher testPublisher = TestPublisher.create(); - - StepVerifier.create( - testPublisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (s, f) -> Flux.error(new RuntimeException("test")), false))) - .expectSubscription() - .thenCancel() - .verify(); - - Assertions.assertThat(testPublisher.wasCancelled()).isTrue(); - } - - @Test - public void shouldCancelUpstreamBeforeFirst() { - EmitterProcessor testPublisher = EmitterProcessor.create(); - - StepVerifier.create( - testPublisher.transform( - flux -> - new FluxSwitchOnFirst<>( - flux, (s, f) -> Flux.error(new RuntimeException("test")), false))) - .thenAwait(Duration.ofMillis(50)) - .thenCancel() - .verify(Duration.ofSeconds(2)); - - Assertions.assertThat(testPublisher.isCancelled()).isTrue(); - } - - @Test - public void shouldContinueWorkingRegardlessTerminalOnDownstream() { - TestPublisher testPublisher = TestPublisher.create(); - - Flux[] intercepted = new Flux[1]; - - StepVerifier.create( - testPublisher - .flux() - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> { - intercepted[0] = f; - return Flux.just(2L); - }, - false))) - .expectSubscription() - .then(() -> testPublisher.next(1L)) - .expectNext(2L) - .expectComplete() - .verify(Duration.ofSeconds(2)); - - Assertions.assertThat(testPublisher.wasCancelled()).isFalse(); - - StepVerifier.create(intercepted[0]) - .expectSubscription() - .then(testPublisher::complete) - .expectComplete() - .verify(Duration.ofSeconds(1)); - } - - @Test - public void shouldCancelSourceOnOnDownstreamTerminal() { - TestPublisher testPublisher = TestPublisher.create(); - - StepVerifier.create( - testPublisher - .flux() - .transform(flux -> new FluxSwitchOnFirst<>(flux, (s, f) -> Flux.just(1L), true))) - .expectSubscription() - .then(() -> testPublisher.next(1L)) - .expectNext(1L) - .expectComplete() - .verify(Duration.ofSeconds(2)); - - Assertions.assertThat(testPublisher.wasCancelled()).isTrue(); - } - - @Test - public void racingTest() { - for (int i = 0; i < 1000; i++) { - CoreSubscriber[] subscribers = new CoreSubscriber[1]; - Subscription[] downstreamSubscriptions = new Subscription[1]; - Subscription[] innerSubscriptions = new Subscription[1]; - - AtomicLong requested = new AtomicLong(); - - Flux.just(2) - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> - new Flux() { - - @Override - public void subscribe(CoreSubscriber actual) { - subscribers[0] = actual; - f.subscribe( - actual::onNext, - actual::onError, - actual::onComplete, - (s) -> innerSubscriptions[0] = s); - } - }, - false)) - .subscribe(__ -> {}, __ -> {}, () -> {}, s -> downstreamSubscriptions[0] = s); - - CoreSubscriber subscriber = subscribers[0]; - Subscription downstreamSubscription = downstreamSubscriptions[0]; - Subscription innerSubscription = innerSubscriptions[0]; - innerSubscription.request(1); - - RaceTestUtils.race( - () -> subscriber.onSubscribe(innerSubscription), () -> downstreamSubscription.request(1)); - - Assertions.assertThat(requested.get()).isEqualTo(3); - } - } - - @Test - public void racingConditionalTest() { - for (int i = 0; i < 1000; i++) { - CoreSubscriber[] subscribers = new CoreSubscriber[1]; - Subscription[] downstreamSubscriptions = new Subscription[1]; - Subscription[] innerSubscriptions = new Subscription[1]; - - AtomicLong requested = new AtomicLong(); - - Flux.just(2) - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new FluxSwitchOnFirst<>( - flux, - (s, f) -> - new Flux() { - - @Override - public void subscribe(CoreSubscriber actual) { - subscribers[0] = actual; - f.subscribe( - new Fuseable.ConditionalSubscriber() { - @Override - public boolean tryOnNext(Integer integer) { - return ((Fuseable.ConditionalSubscriber) actual) - .tryOnNext(integer); - } - - @Override - public void onSubscribe(Subscription s) { - innerSubscriptions[0] = s; - } - - @Override - public void onNext(Integer integer) { - actual.onNext(integer); - } - - @Override - public void onError(Throwable throwable) { - actual.onError(throwable); - } - - @Override - public void onComplete() { - actual.onComplete(); - } - }); - } - }, - false)) - .filter(__ -> true) - .subscribe(__ -> {}, __ -> {}, () -> {}, s -> downstreamSubscriptions[0] = s); - - CoreSubscriber subscriber = subscribers[0]; - Subscription downstreamSubscription = downstreamSubscriptions[0]; - Subscription innerSubscription = innerSubscriptions[0]; - innerSubscription.request(1); - - RaceTestUtils.race( - () -> subscriber.onSubscribe(innerSubscription), () -> downstreamSubscription.request(1)); - - Assertions.assertThat(requested.get()).isEqualTo(3); - } - } - - private static final class NoOpsScheduler implements Scheduler { - - static final NoOpsScheduler INSTANCE = new NoOpsScheduler(); - - private NoOpsScheduler() {} - - @Override - public Disposable schedule(Runnable task) { - return Disposables.composite(); - } - - @Override - public Worker createWorker() { - return NoOpsWorker.INSTANCE; - } - - static final class NoOpsWorker implements Worker { - - static final NoOpsWorker INSTANCE = new NoOpsWorker(); - - @Override - public Disposable schedule(Runnable task) { - return Disposables.never(); - } - - @Override - public void dispose() {} - }; - } -} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java index ca115d281..fe8b6766f 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -36,7 +36,11 @@ public static void main(String[] args) { RSocketFactory.connect() .resume() .resumeStrategy( - () -> new VerboseResumeStrategy(new PeriodicResumeStrategy(Duration.ofSeconds(1)))) + () -> { + System.out.println("created"); + return new VerboseResumeStrategy( + new PeriodicResumeStrategy(Duration.ofSeconds(1))); + }) .resumeSessionDuration(Duration.ofMinutes(5)) .transport(TcpClientTransport.create("localhost", 8001)) .start() From b7522b669c8dc5f2657d52f28a796364ab284e79 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 9 Apr 2020 15:19:49 +0300 Subject: [PATCH 04/11] rollbacks unwanted changes Signed-off-by: Oleh Dokuka --- .../examples/transport/tcp/resume/ResumeFileTransfer.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java index fe8b6766f..ca115d281 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -36,11 +36,7 @@ public static void main(String[] args) { RSocketFactory.connect() .resume() .resumeStrategy( - () -> { - System.out.println("created"); - return new VerboseResumeStrategy( - new PeriodicResumeStrategy(Duration.ofSeconds(1))); - }) + () -> new VerboseResumeStrategy(new PeriodicResumeStrategy(Duration.ofSeconds(1)))) .resumeSessionDuration(Duration.ofMinutes(5)) .transport(TcpClientTransport.create("localhost", 8001)) .start() From 43e3e39a313ef722bc5c1fff7d61ada40841ce10 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Sun, 5 Apr 2020 13:36:59 +0300 Subject: [PATCH 05/11] initial Signed-off-by: Oleh Dokuka --- .../io/rsocket/core/RSocketRequester.java | 37 +++++++++++++++++++ .../fragmentation/FragmentationUtils.java | 33 +++++++++++++++++ .../fragmentation/FrameFragmenter.java | 1 + .../core/RSocketRequesterSubscribersTest.java | 1 + .../src/test/resources/logback-test.xml | 4 +- 5 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 6c26361a2..3c7ebe458 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -88,6 +88,7 @@ class RSocketRequester implements RSocket { private final IntObjectMap senders; private final IntObjectMap> receivers; private final UnboundedProcessor sendProcessor; + private final int mtu; private final RequesterLeaseHandler leaseHandler; private final ByteBufAllocator allocator; private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; @@ -99,6 +100,7 @@ class RSocketRequester implements RSocket { PayloadDecoder payloadDecoder, Consumer errorConsumer, StreamIdSupplier streamIdSupplier, + int mtu, int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, @@ -108,6 +110,7 @@ class RSocketRequester implements RSocket { this.payloadDecoder = payloadDecoder; this.errorConsumer = errorConsumer; this.streamIdSupplier = streamIdSupplier; + this.mtu = mtu; this.leaseHandler = leaseHandler; this.senders = new SynchronizedIntObjectHashMap<>(); this.receivers = new SynchronizedIntObjectHashMap<>(); @@ -186,6 +189,11 @@ private Mono handleFireAndForget(Payload payload) { return Mono.error(err); } + if (!FragmentationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException("Too big Payload size")); + } + final int streamId = streamIdSupplier.nextStreamId(receivers); return UnicastMonoEmpty.newInstance( @@ -210,6 +218,11 @@ private Mono handleRequestResponse(final Payload payload) { return Mono.error(err); } + if (!FragmentationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException("Too big Payload size")); + } + int streamId = streamIdSupplier.nextStreamId(receivers); final UnboundedProcessor sendProcessor = this.sendProcessor; @@ -255,6 +268,11 @@ private Flux handleRequestStream(final Payload payload) { return Flux.error(err); } + if (!FragmentationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Flux.error(new IllegalArgumentException("Too big Payload size")); + } + int streamId = streamIdSupplier.nextStreamId(receivers); final UnboundedProcessor sendProcessor = this.sendProcessor; @@ -318,6 +336,13 @@ private Flux handleChannel(Flux request) { Payload payload = s.get(); if (payload != null) { return handleChannel(payload, flux); + if (!FragmentationUtils.isValid(mtu, payload)) { + payload.release(); + final IllegalArgumentException t = new IllegalArgumentException("Too big Payload size"); + errorConsumer.accept(t); + return Mono.error(t); + } + return handleChannel(payload, flux.skip(1)); } else { return flux; } @@ -343,11 +368,23 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnNext(Payload payload) { + if (first) { // need to skip first since we have already sent it first = false; return; } + + if (!FragmentationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = new IllegalArgumentException("Too big Payload size"); + errorConsumer.accept(t); + // no need to send any errors. + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + receiver.onError(t); + return ; + } final ByteBuf frame = PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, payload); diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java new file mode 100644 index 000000000..7cf694901 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java @@ -0,0 +1,33 @@ +package io.rsocket.fragmentation; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; + +public final class FragmentationUtils { + public static boolean isValid(int mtu, Payload payload) { + return payload.hasMetadata() ? isValid(mtu, payload.data(), payload.metadata()) : isValid(mtu, payload.metadata()); + } + + public static boolean isValid(int mtu, ByteBuf data) { + return mtu > 0 + || (((FrameHeaderFlyweight.size() + + data.readableBytes() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } + + public static boolean isValid(int mtu, ByteBuf data, ByteBuf metadata) { + return mtu > 0 + || (((FrameHeaderFlyweight.size() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE + + FrameHeaderFlyweight.size() + + data.readableBytes() + + metadata.readableBytes()) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } + +} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index d634f7374..a7f2d8ea8 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -20,6 +20,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; import io.rsocket.frame.*; import java.util.function.Consumer; import org.reactivestreams.Publisher; diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index 8a2e114cc..8380290f2 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -67,6 +67,7 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); } diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml index f9dec2bbe..a4442c6c4 100644 --- a/rsocket-transport-netty/src/test/resources/logback-test.xml +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -23,8 +23,8 @@ - - + + From d5c67f2791b85fbc638cfd44e4557633f8fd1a36 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Sun, 5 Apr 2020 14:33:47 +0300 Subject: [PATCH 06/11] provides payload length validation in case fragmentation is disabled Signed-off-by: Oleh Dokuka --- .../io/rsocket/core/RSocketRequester.java | 39 +++++++++--------- .../io/rsocket/core/RSocketResponder.java | 25 ++++++++++- .../fragmentation/FragmentationUtils.java | 41 ++++++++++--------- .../fragmentation/FrameFragmenter.java | 13 ++++-- 4 files changed, 74 insertions(+), 44 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 3c7ebe458..13fa27400 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -335,14 +335,14 @@ private Flux handleChannel(Flux request) { (s, flux) -> { Payload payload = s.get(); if (payload != null) { + if (!FragmentationUtils.isValid(mtu, payload)) { + payload.release(); + final IllegalArgumentException t = + new IllegalArgumentException("Too big Payload size"); + errorConsumer.accept(t); + return Mono.error(t); + } return handleChannel(payload, flux); - if (!FragmentationUtils.isValid(mtu, payload)) { - payload.release(); - final IllegalArgumentException t = new IllegalArgumentException("Too big Payload size"); - errorConsumer.accept(t); - return Mono.error(t); - } - return handleChannel(payload, flux.skip(1)); } else { return flux; } @@ -368,23 +368,22 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnNext(Payload payload) { - - if (first) { +if (first) { // need to skip first since we have already sent it first = false; return; } - - if (!FragmentationUtils.isValid(mtu, payload)) { - payload.release(); - cancel(); - final IllegalArgumentException t = new IllegalArgumentException("Too big Payload size"); - errorConsumer.accept(t); - // no need to send any errors. - sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - receiver.onError(t); - return ; - } +if (!FragmentationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException("Too big Payload size"); + errorConsumer.accept(t); + // no need to send any errors. + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + receiver.onError(t); + return; + } final ByteBuf frame = PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, payload); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index de6e8ad23..36e1a8f7f 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -25,6 +25,7 @@ import io.rsocket.RSocket; import io.rsocket.ResponderRSocket; import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.fragmentation.FragmentationUtils; import io.rsocket.frame.*; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.SynchronizedIntObjectHashMap; @@ -51,6 +52,8 @@ class RSocketResponder implements ResponderRSocket { private final Consumer errorConsumer; private final ResponderLeaseHandler leaseHandler; + private final int mtu; + private final IntObjectMap sendingSubscriptions; private final IntObjectMap> channelProcessors; @@ -63,9 +66,11 @@ class RSocketResponder implements ResponderRSocket { RSocket requestHandler, PayloadDecoder payloadDecoder, Consumer errorConsumer, - ResponderLeaseHandler leaseHandler) { + ResponderLeaseHandler leaseHandler, + int mtu) { this.allocator = allocator; this.connection = connection; + this.mtu = mtu; this.requestHandler = requestHandler; this.responderRSocket = @@ -371,6 +376,15 @@ protected void hookOnNext(Payload payload) { isEmpty = false; } + if (!FragmentationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException("Too big Payload size"); + handleError(streamId, t); + return; + } + ByteBuf byteBuf; try { byteBuf = PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload); @@ -417,6 +431,15 @@ protected void hookOnSubscribe(Subscription s) { @Override protected void hookOnNext(Payload payload) { + if (!FragmentationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException("Too big Payload size"); + handleError(streamId, t); + return; + } + ByteBuf byteBuf; try { byteBuf = PayloadFrameFlyweight.encodeNext(allocator, streamId, payload); diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java index 7cf694901..8868c82d7 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java @@ -6,28 +6,29 @@ import io.rsocket.frame.FrameLengthFlyweight; public final class FragmentationUtils { - public static boolean isValid(int mtu, Payload payload) { - return payload.hasMetadata() ? isValid(mtu, payload.data(), payload.metadata()) : isValid(mtu, payload.metadata()); - } + public static boolean isValid(int mtu, Payload payload) { + return payload.hasMetadata() + ? isValid(mtu, payload.data(), payload.metadata()) + : isValid(mtu, payload.metadata()); + } - public static boolean isValid(int mtu, ByteBuf data) { - return mtu > 0 - || (((FrameHeaderFlyweight.size() - + data.readableBytes() - + FrameLengthFlyweight.FRAME_LENGTH_SIZE) + public static boolean isValid(int mtu, ByteBuf data) { + return mtu > 0 + || (((FrameHeaderFlyweight.size() + + data.readableBytes() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE) & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) - == 0); - } + == 0); + } - public static boolean isValid(int mtu, ByteBuf data, ByteBuf metadata) { - return mtu > 0 - || (((FrameHeaderFlyweight.size() - + FrameLengthFlyweight.FRAME_LENGTH_SIZE - + FrameHeaderFlyweight.size() - + data.readableBytes() - + metadata.readableBytes()) + public static boolean isValid(int mtu, ByteBuf data, ByteBuf metadata) { + return mtu > 0 + || (((FrameHeaderFlyweight.size() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE + + FrameHeaderFlyweight.size() + + data.readableBytes() + + metadata.readableBytes()) & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) - == 0); - } - + == 0); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index a7f2d8ea8..177089f1a 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -20,13 +20,20 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCountUtil; -import io.rsocket.Payload; -import io.rsocket.frame.*; -import java.util.function.Consumer; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.SynchronousSink; +import java.util.function.Consumer; + /** * The implementation of the RSocket fragmentation behavior. * From 7e93ab973a69916cf136804da8425d052d6cba9c Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Sun, 5 Apr 2020 14:35:13 +0300 Subject: [PATCH 07/11] provides payload validation Signed-off-by: Oleh Dokuka --- .../io/rsocket/core/RSocketRequester.java | 5 ++ .../fragmentation/FrameFragmenter.java | 3 +- .../java/io/rsocket/core/KeepAliveTest.java | 2 + .../io/rsocket/core/RSocketLeaseTest.java | 4 +- .../io/rsocket/core/RSocketRequesterTest.java | 85 +++++++++++++++++- .../io/rsocket/core/RSocketResponderTest.java | 59 ++++++++++++- .../java/io/rsocket/core/RSocketTest.java | 4 +- .../io/rsocket/core/SetupRejectionTest.java | 2 + .../fragmentation/FragmentationUtilsTest.java | 87 +++++++++++++++++++ .../src/test/resources/logback-test.xml | 4 +- 10 files changed, 245 insertions(+), 10 deletions(-) create mode 100644 rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 13fa27400..e2fbd5c14 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -470,6 +470,11 @@ private Mono handleMetadataPush(Payload payload) { return Mono.error(err); } + if (!FragmentationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException("Too big Payload size")); + } + return UnicastMonoEmpty.newInstance( () -> { ByteBuf metadataPushFrame = diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index 177089f1a..e59ece86f 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -28,12 +28,11 @@ import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; import io.rsocket.frame.RequestResponseFrameFlyweight; import io.rsocket.frame.RequestStreamFrameFlyweight; +import java.util.function.Consumer; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.SynchronousSink; -import java.util.function.Consumer; - /** * The implementation of the RSocket fragmentation behavior. * diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index 6cb05dec1..10725238a 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -61,6 +61,7 @@ static RSocketState requester(int tickPeriod, int timeout) { DefaultPayload::create, errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new DefaultKeepAliveHandler(connection), @@ -86,6 +87,7 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { DefaultPayload::create, errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new ResumableKeepAliveHandler(resumableConnection), diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 3cbb3c5d7..0a7f7a196 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -94,6 +94,7 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, requesterLeaseHandler); @@ -111,7 +112,8 @@ void setUp() { mockRSocketHandler, payloadDecoder, err -> {}, - responderLeaseHandler); + responderLeaseHandler, + 0); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index b6dbf71de..7788c194a 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -17,9 +17,19 @@ package io.rsocket.core; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static io.rsocket.frame.FrameType.*; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.KEEPALIVE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; @@ -29,7 +39,15 @@ import io.rsocket.Payload; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.*; +import io.rsocket.frame.CancelFrameFlyweight; +import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestNFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.DefaultPayload; @@ -39,7 +57,10 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.assertj.core.api.Assertions; import org.junit.Rule; import org.junit.Test; @@ -51,6 +72,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; +import reactor.test.StepVerifier; public class RSocketRequesterTest { @@ -267,6 +289,62 @@ protected void hookOnSubscribe(Subscription subscription) {} Assertions.assertThat(iterator.hasNext()).isFalse(); } + @Test + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + generator.apply(rule.socket, DefaultPayload.create(data, metadata))) + .expectSubscription() + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Too big Payload size")) + .verify(); + }); + } + + @Test + public void + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + rule.socket.requestChannel( + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata)))) + .expectSubscription() + .then( + () -> + rule.connection.addToReceivedBuffer( + RequestNFrameFlyweight.encode( + ByteBufAllocator.DEFAULT, + rule.getStreamIdForRequestType(REQUEST_CHANNEL), + 1))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Too big Payload size")) + .verify(); + } + + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + public int sendRequestResponse(Publisher response) { Subscriber sub = TestSubscriber.create(); response.subscribe(sub); @@ -290,6 +368,7 @@ protected RSocketRequester newRSocket() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 10157532a..ed9d2a89a 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -34,11 +34,15 @@ import io.rsocket.util.EmptyPayload; import java.util.Collection; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; +import org.assertj.core.api.Assertions; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class RSocketResponderTest { @@ -110,6 +114,58 @@ public Mono requestResponse(Payload payload) { assertThat("Subscription not cancelled.", cancelled.get(), is(true)); } + @Test + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + final AbstractRSocket acceptingSocket = + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload p) { + return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestStream(Payload p) { + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + }; + rule.setAcceptingSocket(acceptingSocket); + + final Runnable[] runnables = { + () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + }; + + for (Runnable runnable : runnables) { + runnable.run(); + Assertions.assertThat(rule.errors) + .first() + .isInstanceOf(IllegalArgumentException.class) + .hasToString("java.lang.IllegalArgumentException: Too big Payload size"); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR) + .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains("Too big Payload size")); + + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + rule.init(); + rule.setAcceptingSocket(acceptingSocket); + } + } + public static class ServerSocketRule extends AbstractSocketRule { private RSocket acceptingSocket; @@ -151,7 +207,8 @@ protected RSocketResponder newRSocket() { acceptingSocket, DefaultPayload::create, throwable -> errors.add(throwable), - ResponderLeaseHandler.None); + ResponderLeaseHandler.None, + 0); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index b18fad890..edcc8971f 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -222,7 +222,8 @@ public Flux requestChannel(Publisher payloads) { requestAcceptor, DefaultPayload::create, throwable -> serverErrors.add(throwable), - ResponderLeaseHandler.None); + ResponderLeaseHandler.None, + 0); crs = new RSocketRequester( @@ -233,6 +234,7 @@ public Flux requestChannel(Publisher payloads) { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index daab5d246..9344d69da 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -58,6 +58,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); @@ -93,6 +94,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java new file mode 100644 index 000000000..4e130653d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java @@ -0,0 +1,87 @@ +package io.rsocket.fragmentation; + +import static org.junit.jupiter.api.Assertions.*; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class FragmentationUtilsTest { + + @Test + void shouldValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldValidFrameWithNoFragmentation0() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2]; + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK / 2 + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size() + - FrameHeaderFlyweight.size()]; + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldValidFrameWithNoFragmentation1() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldValidFrameWithNoFragmentation2() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldValidFrameWithNoFragmentation3() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(FragmentationUtils.isValid(64, payload)).isTrue(); + } + + @Test + void shouldValidFrameWithNoFragmentation4() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(FragmentationUtils.isValid(64, payload)).isTrue(); + } +} diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml index a4442c6c4..f9dec2bbe 100644 --- a/rsocket-transport-netty/src/test/resources/logback-test.xml +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -23,8 +23,8 @@ - - + + From c6ee7eebc6a6d4ff918f2bde8887f0755a634a44 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 9 Apr 2020 14:34:47 +0300 Subject: [PATCH 08/11] fixes tests and changes error message Signed-off-by: Oleh Dokuka --- .../io/rsocket/core/RSocketRequester.java | 22 +++++++++++---- .../io/rsocket/core/RSocketResponder.java | 6 ++-- .../fragmentation/FragmentationUtils.java | 2 +- .../io/rsocket/core/RSocketRequesterTest.java | 4 +-- .../io/rsocket/core/RSocketResponderTest.java | 4 +-- .../fragmentation/FragmentationUtilsTest.java | 28 +++++++++++++------ 6 files changed, 45 insertions(+), 21 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index e2fbd5c14..934c3b017 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -191,7 +191,9 @@ private Mono handleFireAndForget(Payload payload) { if (!FragmentationUtils.isValid(this.mtu, payload)) { payload.release(); - return Mono.error(new IllegalArgumentException("Too big Payload size")); + return Mono.error( + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); } final int streamId = streamIdSupplier.nextStreamId(receivers); @@ -220,7 +222,9 @@ private Mono handleRequestResponse(final Payload payload) { if (!FragmentationUtils.isValid(this.mtu, payload)) { payload.release(); - return Mono.error(new IllegalArgumentException("Too big Payload size")); + return Mono.error( + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); } int streamId = streamIdSupplier.nextStreamId(receivers); @@ -270,7 +274,9 @@ private Flux handleRequestStream(final Payload payload) { if (!FragmentationUtils.isValid(this.mtu, payload)) { payload.release(); - return Flux.error(new IllegalArgumentException("Too big Payload size")); + return Flux.error( + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); } int streamId = streamIdSupplier.nextStreamId(receivers); @@ -338,7 +344,8 @@ private Flux handleChannel(Flux request) { if (!FragmentationUtils.isValid(mtu, payload)) { payload.release(); final IllegalArgumentException t = - new IllegalArgumentException("Too big Payload size"); + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); errorConsumer.accept(t); return Mono.error(t); } @@ -377,7 +384,8 @@ protected void hookOnNext(Payload payload) { payload.release(); cancel(); final IllegalArgumentException t = - new IllegalArgumentException("Too big Payload size"); + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); errorConsumer.accept(t); // no need to send any errors. sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); @@ -472,7 +480,9 @@ private Mono handleMetadataPush(Payload payload) { if (!FragmentationUtils.isValid(this.mtu, payload)) { payload.release(); - return Mono.error(new IllegalArgumentException("Too big Payload size")); + return Mono.error( + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); } return UnicastMonoEmpty.newInstance( diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 36e1a8f7f..d73f3db54 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -380,7 +380,8 @@ protected void hookOnNext(Payload payload) { payload.release(); cancel(); final IllegalArgumentException t = - new IllegalArgumentException("Too big Payload size"); + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); handleError(streamId, t); return; } @@ -435,7 +436,8 @@ protected void hookOnNext(Payload payload) { payload.release(); cancel(); final IllegalArgumentException t = - new IllegalArgumentException("Too big Payload size"); + new IllegalArgumentException( + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); handleError(streamId, t); return; } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java index 8868c82d7..09f658589 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java @@ -9,7 +9,7 @@ public final class FragmentationUtils { public static boolean isValid(int mtu, Payload payload) { return payload.hasMetadata() ? isValid(mtu, payload.data(), payload.metadata()) - : isValid(mtu, payload.metadata()); + : isValid(mtu, payload.data()); } public static boolean isValid(int mtu, ByteBuf data) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 7788c194a..a53b73d96 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -305,7 +305,7 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen t -> Assertions.assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Too big Payload size")) + .hasMessage("The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")) .verify(); }); } @@ -332,7 +332,7 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen t -> Assertions.assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Too big Payload size")) + .hasMessage("The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")) .verify(); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index ed9d2a89a..1860a7cd0 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -153,12 +153,12 @@ public Flux requestChannel(Publisher payloads) { Assertions.assertThat(rule.errors) .first() .isInstanceOf(IllegalArgumentException.class) - .hasToString("java.lang.IllegalArgumentException: Too big Payload size"); + .hasToString("java.lang.IllegalArgumentException: The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); Assertions.assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR) - .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains("Too big Payload size")); + .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains("The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); assertThat("Subscription not cancelled.", cancelled.get(), is(true)); rule.init(); diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java index 4e130653d..5c777377c 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java @@ -1,7 +1,5 @@ package io.rsocket.fragmentation; -import static org.junit.jupiter.api.Assertions.*; - import io.rsocket.Payload; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameLengthFlyweight; @@ -13,7 +11,7 @@ class FragmentationUtilsTest { @Test - void shouldValidFrameWithNoFragmentation() { + void shouldBeValidFrameWithNoFragmentation() { byte[] data = new byte [FrameLengthFlyweight.FRAME_LENGTH_MASK @@ -26,7 +24,21 @@ void shouldValidFrameWithNoFragmentation() { } @Test - void shouldValidFrameWithNoFragmentation0() { + void shouldBeInValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size() + + 1]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation0() { byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2]; byte[] data = new byte @@ -42,7 +54,7 @@ void shouldValidFrameWithNoFragmentation0() { } @Test - void shouldValidFrameWithNoFragmentation1() { + void shouldBeInValidFrameWithNoFragmentation1() { byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; ThreadLocalRandom.current().nextBytes(metadata); @@ -53,7 +65,7 @@ void shouldValidFrameWithNoFragmentation1() { } @Test - void shouldValidFrameWithNoFragmentation2() { + void shouldBeValidFrameWithNoFragmentation2() { byte[] metadata = new byte[1]; byte[] data = new byte[1]; ThreadLocalRandom.current().nextBytes(metadata); @@ -64,7 +76,7 @@ void shouldValidFrameWithNoFragmentation2() { } @Test - void shouldValidFrameWithNoFragmentation3() { + void shouldBeValidFrameWithNoFragmentation3() { byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; ThreadLocalRandom.current().nextBytes(metadata); @@ -75,7 +87,7 @@ void shouldValidFrameWithNoFragmentation3() { } @Test - void shouldValidFrameWithNoFragmentation4() { + void shouldBeValidFrameWithNoFragmentation4() { byte[] metadata = new byte[1]; byte[] data = new byte[1]; ThreadLocalRandom.current().nextBytes(metadata); From 7adafdbe2ab5bb7a245ee5424a62f8d1965bfe8d Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 9 Apr 2020 14:48:28 +0300 Subject: [PATCH 09/11] moves validator to core package and makes it pkg private Signed-off-by: Oleh Dokuka --- .../PayloadValidationUtils.java} | 4 ++-- .../java/io/rsocket/core/RSocketRequester.java | 14 +++++++------- .../java/io/rsocket/core/RSocketResponder.java | 5 ++--- .../PayloadValidationUtilsTest.java} | 18 +++++++++--------- 4 files changed, 20 insertions(+), 21 deletions(-) rename rsocket-core/src/main/java/io/rsocket/{fragmentation/FragmentationUtils.java => core/PayloadValidationUtils.java} (93%) rename rsocket-core/src/test/java/io/rsocket/{fragmentation/FragmentationUtilsTest.java => core/PayloadValidationUtilsTest.java} (82%) diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java similarity index 93% rename from rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java rename to rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java index 09f658589..64a474814 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -1,11 +1,11 @@ -package io.rsocket.fragmentation; +package io.rsocket.core; import io.netty.buffer.ByteBuf; import io.rsocket.Payload; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameLengthFlyweight; -public final class FragmentationUtils { +final class PayloadValidationUtils { public static boolean isValid(int mtu, Payload payload) { return payload.hasMetadata() ? isValid(mtu, payload.data(), payload.metadata()) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 934c3b017..c8cc5054c 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -189,7 +189,7 @@ private Mono handleFireAndForget(Payload payload) { return Mono.error(err); } - if (!FragmentationUtils.isValid(this.mtu, payload)) { + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); return Mono.error( new IllegalArgumentException( @@ -220,7 +220,7 @@ private Mono handleRequestResponse(final Payload payload) { return Mono.error(err); } - if (!FragmentationUtils.isValid(this.mtu, payload)) { + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); return Mono.error( new IllegalArgumentException( @@ -272,7 +272,7 @@ private Flux handleRequestStream(final Payload payload) { return Flux.error(err); } - if (!FragmentationUtils.isValid(this.mtu, payload)) { + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); return Flux.error( new IllegalArgumentException( @@ -341,7 +341,7 @@ private Flux handleChannel(Flux request) { (s, flux) -> { Payload payload = s.get(); if (payload != null) { - if (!FragmentationUtils.isValid(mtu, payload)) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); final IllegalArgumentException t = new IllegalArgumentException( @@ -375,12 +375,12 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnNext(Payload payload) { -if (first) { + if (first) { // need to skip first since we have already sent it first = false; return; } -if (!FragmentationUtils.isValid(mtu, payload)) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); cancel(); final IllegalArgumentException t = @@ -478,7 +478,7 @@ private Mono handleMetadataPush(Payload payload) { return Mono.error(err); } - if (!FragmentationUtils.isValid(this.mtu, payload)) { + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); return Mono.error( new IllegalArgumentException( diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index d73f3db54..a4aa620d1 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -25,7 +25,6 @@ import io.rsocket.RSocket; import io.rsocket.ResponderRSocket; import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.fragmentation.FragmentationUtils; import io.rsocket.frame.*; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.SynchronizedIntObjectHashMap; @@ -376,7 +375,7 @@ protected void hookOnNext(Payload payload) { isEmpty = false; } - if (!FragmentationUtils.isValid(mtu, payload)) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); cancel(); final IllegalArgumentException t = @@ -432,7 +431,7 @@ protected void hookOnSubscribe(Subscription s) { @Override protected void hookOnNext(Payload payload) { - if (!FragmentationUtils.isValid(mtu, payload)) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); cancel(); final IllegalArgumentException t = diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java similarity index 82% rename from rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java rename to rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java index 5c777377c..e91fce848 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationUtilsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -1,4 +1,4 @@ -package io.rsocket.fragmentation; +package io.rsocket.core; import io.rsocket.Payload; import io.rsocket.frame.FrameHeaderFlyweight; @@ -8,7 +8,7 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; -class FragmentationUtilsTest { +class PayloadValidationUtilsTest { @Test void shouldBeValidFrameWithNoFragmentation() { @@ -20,7 +20,7 @@ void shouldBeValidFrameWithNoFragmentation() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data); - Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); } @Test @@ -34,7 +34,7 @@ void shouldBeInValidFrameWithNoFragmentation() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data); - Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); } @Test @@ -50,7 +50,7 @@ void shouldBeValidFrameWithNoFragmentation0() { ThreadLocalRandom.current().nextBytes(metadata); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); } @Test @@ -61,7 +61,7 @@ void shouldBeInValidFrameWithNoFragmentation1() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); } @Test @@ -72,7 +72,7 @@ void shouldBeValidFrameWithNoFragmentation2() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); } @Test @@ -83,7 +83,7 @@ void shouldBeValidFrameWithNoFragmentation3() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(FragmentationUtils.isValid(64, payload)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); } @Test @@ -94,6 +94,6 @@ void shouldBeValidFrameWithNoFragmentation4() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(FragmentationUtils.isValid(64, payload)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); } } From ab7ff18ce3e48f531073f98d0fa5c08c2fc570bf Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 9 Apr 2020 15:09:45 +0300 Subject: [PATCH 10/11] simplifies validation utils. puts error message to constant field Signed-off-by: Oleh Dokuka --- .../core/DefaultClientRSocketFactory.java | 4 +- .../core/DefaultServerRSocketFactory.java | 4 +- .../rsocket/core/PayloadValidationUtils.java | 44 +++++++++---------- .../io/rsocket/core/RSocketRequester.java | 23 +++------- .../io/rsocket/core/RSocketResponder.java | 8 ++-- .../io/rsocket/core/RSocketRequesterTest.java | 6 ++- .../io/rsocket/core/RSocketResponderTest.java | 5 ++- 7 files changed, 45 insertions(+), 49 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java index ce43cd1fd..b7cad7042 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java @@ -331,6 +331,7 @@ public Mono start() { payloadDecoder, errorConsumer, StreamIdSupplier.clientSupplier(), + mtu, keepAliveTickPeriod(), keepAliveTimeout(), keepAliveHandler, @@ -379,7 +380,8 @@ public Mono start() { wrappedRSocketHandler, payloadDecoder, errorConsumer, - responderLeaseHandler); + responderLeaseHandler, + mtu); return wrappedConnection .sendOne(setupFrame) diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java index f2acb9af0..85543181a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java @@ -281,6 +281,7 @@ private Mono acceptSetup( payloadDecoder, errorConsumer, StreamIdSupplier.serverSupplier(), + mtu, setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, @@ -317,7 +318,8 @@ private Mono acceptSetup( wrappedRSocketHandler, payloadDecoder, errorConsumer, - responderLeaseHandler); + responderLeaseHandler, + mtu); }) .doFinally(signalType -> setupPayload.release()) .then(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java index 64a474814..3b6b375d1 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -1,34 +1,32 @@ package io.rsocket.core; -import io.netty.buffer.ByteBuf; import io.rsocket.Payload; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameLengthFlyweight; final class PayloadValidationUtils { - public static boolean isValid(int mtu, Payload payload) { - return payload.hasMetadata() - ? isValid(mtu, payload.data(), payload.metadata()) - : isValid(mtu, payload.data()); - } + static final String INVALID_PAYLOAD_ERROR_MESSAGE = + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."; - public static boolean isValid(int mtu, ByteBuf data) { - return mtu > 0 - || (((FrameHeaderFlyweight.size() - + data.readableBytes() - + FrameLengthFlyweight.FRAME_LENGTH_SIZE) - & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) - == 0); - } + static boolean isValid(int mtu, Payload payload) { + if (mtu > 0) { + return true; + } - public static boolean isValid(int mtu, ByteBuf data, ByteBuf metadata) { - return mtu > 0 - || (((FrameHeaderFlyweight.size() - + FrameLengthFlyweight.FRAME_LENGTH_SIZE - + FrameHeaderFlyweight.size() - + data.readableBytes() - + metadata.readableBytes()) - & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) - == 0); + if (payload.hasMetadata()) { + return (((FrameHeaderFlyweight.size() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE + + FrameHeaderFlyweight.size() + + payload.data().readableBytes() + + payload.metadata().readableBytes()) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } else { + return (((FrameHeaderFlyweight.size() + + payload.data().readableBytes() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index c8cc5054c..b3fe61214 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -16,6 +16,7 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; @@ -191,9 +192,7 @@ private Mono handleFireAndForget(Payload payload) { if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); - return Mono.error( - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); } final int streamId = streamIdSupplier.nextStreamId(receivers); @@ -222,9 +221,7 @@ private Mono handleRequestResponse(final Payload payload) { if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); - return Mono.error( - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); } int streamId = streamIdSupplier.nextStreamId(receivers); @@ -274,9 +271,7 @@ private Flux handleRequestStream(final Payload payload) { if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); - return Flux.error( - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); + return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); } int streamId = streamIdSupplier.nextStreamId(receivers); @@ -344,8 +339,7 @@ private Flux handleChannel(Flux request) { if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); final IllegalArgumentException t = - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); errorConsumer.accept(t); return Mono.error(t); } @@ -384,8 +378,7 @@ protected void hookOnNext(Payload payload) { payload.release(); cancel(); final IllegalArgumentException t = - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); errorConsumer.accept(t); // no need to send any errors. sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); @@ -480,9 +473,7 @@ private Mono handleMetadataPush(Payload payload) { if (!PayloadValidationUtils.isValid(this.mtu, payload)) { payload.release(); - return Mono.error( - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); } return UnicastMonoEmpty.newInstance( diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index a4aa620d1..6f235587a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -16,6 +16,8 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.ReferenceCountUtil; @@ -379,8 +381,7 @@ protected void hookOnNext(Payload payload) { payload.release(); cancel(); final IllegalArgumentException t = - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); handleError(streamId, t); return; } @@ -435,8 +436,7 @@ protected void hookOnNext(Payload payload) { payload.release(); cancel(); final IllegalArgumentException t = - new IllegalArgumentException( - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); handleError(streamId, t); return; } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index a53b73d96..08fd4718c 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -16,6 +16,7 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; import static io.rsocket.frame.FrameType.CANCEL; import static io.rsocket.frame.FrameType.KEEPALIVE; @@ -37,6 +38,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.util.CharsetUtil; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.CancelFrameFlyweight; @@ -305,7 +307,7 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen t -> Assertions.assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) .verify(); }); } @@ -332,7 +334,7 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen t -> Assertions.assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) .verify(); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 1860a7cd0..5c147f46f 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -16,6 +16,7 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.*; @@ -153,12 +154,12 @@ public Flux requestChannel(Publisher payloads) { Assertions.assertThat(rule.errors) .first() .isInstanceOf(IllegalArgumentException.class) - .hasToString("java.lang.IllegalArgumentException: The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."); + .hasToString("java.lang.IllegalArgumentException: " + INVALID_PAYLOAD_ERROR_MESSAGE); Assertions.assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR) - .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains("The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.")); + .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)); assertThat("Subscription not cancelled.", cancelled.get(), is(true)); rule.init(); From f8676af553314c6e13319ed804428dd0ec1ec668 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 9 Apr 2020 20:02:50 +0300 Subject: [PATCH 11/11] fixes IDE warning Signed-off-by: Oleh Dokuka --- .../main/java/io/rsocket/core/RSocketRequester.java | 1 + .../java/io/rsocket/core/RSocketRequesterTest.java | 11 +++-------- .../src/main/java/io/rsocket/test/TransportTest.java | 3 +-- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index b3fe61214..fc3175b15 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -486,6 +486,7 @@ private Mono handleMetadataPush(Payload payload) { }); } + @Nullable private Throwable checkAvailable() { Throwable err = this.terminationError; if (err != null) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 08fd4718c..101500da7 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -277,17 +277,12 @@ protected void hookOnSubscribe(Subscription subscription) {} ByteBuf initialFrame = iterator.next(); Assertions.assertThat(FrameHeaderFlyweight.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); - Assertions.assertThat(RequestChannelFrameFlyweight.initialRequestN(initialFrame)).isEqualTo(1); + Assertions.assertThat(RequestChannelFrameFlyweight.initialRequestN(initialFrame)) + .isEqualTo(Integer.MAX_VALUE); Assertions.assertThat( RequestChannelFrameFlyweight.data(initialFrame).toString(CharsetUtil.UTF_8)) .isEqualTo("0"); - ByteBuf requestNFrame = iterator.next(); - - Assertions.assertThat(FrameHeaderFlyweight.frameType(requestNFrame)).isEqualTo(REQUEST_N); - Assertions.assertThat(RequestNFrameFlyweight.requestN(requestNFrame)) - .isEqualTo(Integer.MAX_VALUE); - Assertions.assertThat(iterator.hasNext()).isFalse(); } @@ -329,7 +324,7 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen RequestNFrameFlyweight.encode( ByteBufAllocator.DEFAULT, rule.getStreamIdForRequestType(REQUEST_CHANNEL), - 1))) + 2))) .expectErrorSatisfies( t -> Assertions.assertThat(t) diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index 56ea60feb..1c00b0502 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -219,8 +219,7 @@ default void requestChannel2_000_000() { default void requestChannel3() { AtomicLong requested = new AtomicLong(); Flux payloads = - Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); - + Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); getClient() .requestChannel(payloads)