diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 2ecdec215..a249ea888 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,7 +64,6 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.Operators; import reactor.core.publisher.SignalType; import reactor.core.publisher.UnicastProcessor; import reactor.core.scheduler.Scheduler; @@ -267,68 +266,54 @@ private Mono handleRequestResponse(final Payload payload) { final UnboundedProcessor sendProcessor = this.sendProcessor; final UnicastProcessor receiver = UnicastProcessor.create(Queues.one().get()); - final AtomicBoolean once = new AtomicBoolean(); + return Mono.fromDirect( + new RequestOperator( + receiver.next(), "RequestResponseMono allows only a single subscriber") { - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("RequestResponseMono allows only a single subscriber")); - } + @Override + void hookOnFirstRequest(long n) { + if (isDisposed()) { + payload.release(); + final Throwable t = terminationError; + receiver.onError(t); + return; + } - return receiver - .next() - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - receiver.onError(lh.leaseError()); - return; - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - ByteBuf requestResponseFrame = - RequestResponseFrameCodec.encodeReleasingPayload( - allocator, streamId, payload); - - receivers.put(streamId, receiver); - sendProcessor.onNext(requestResponseFrame); - } - - @Override - void hookOnCancel() { - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else { - if (this.firstRequest) { - payload.release(); - } - } - } - - @Override - public void hookOnTerminal(SignalType signalType) { - receivers.remove(streamId, receiver); - } - })) - .subscribeOn(serialScheduler) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - }); + RequesterLeaseHandler lh = leaseHandler; + if (!lh.useLease()) { + payload.release(); + receiver.onError(lh.leaseError()); + return; + } + + int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + + receivers.put(streamId, receiver); + sendProcessor.onNext(requestResponseFrame); + } + + @Override + void hookOnCancel() { + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } else { + if (this.firstRequest) { + payload.release(); + } + } + } + + @Override + public void hookOnTerminal(SignalType signalType) { + receivers.remove(streamId, receiver); + } + }) + .subscribeOn(serialScheduler) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Flux handleRequestStream(final Payload payload) { @@ -348,79 +333,65 @@ private Flux handleRequestStream(final Payload payload) { } final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(); - final AtomicBoolean once = new AtomicBoolean(); + final UnicastProcessor receiver = UnicastProcessor.create(Queues.one().get()); - return Flux.defer( - () -> { - if (once.getAndSet(true)) { - return Flux.error( - new IllegalStateException("RequestStreamFlux allows only a single subscriber")); - } + return Flux.from( + new RequestOperator(receiver, "RequestStreamFlux allows only a single subscriber") { - return receiver - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - receiver.onError(lh.leaseError()); - return; - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - ByteBuf requestStreamFrame = - RequestStreamFrameCodec.encodeReleasingPayload( - allocator, streamId, n, payload); - - receivers.put(streamId, receiver); - - sendProcessor.onNext(requestStreamFrame); - } - - @Override - void hookOnRemainingRequests(long n) { - if (receiver.isDisposed()) { - return; - } - - sendProcessor.onNext( - RequestNFrameCodec.encode(allocator, streamId, n)); - } - - @Override - void hookOnCancel() { - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else { - if (this.firstRequest) { - payload.release(); - } - } - } - - @Override - void hookOnTerminal(SignalType signalType) { - receivers.remove(streamId); - } - })) - .subscribeOn(serialScheduler, false) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - }); + @Override + void hookOnFirstRequest(long n) { + if (isDisposed()) { + payload.release(); + final Throwable t = terminationError; + receiver.onError(t); + return; + } + + RequesterLeaseHandler lh = leaseHandler; + if (!lh.useLease()) { + payload.release(); + receiver.onError(lh.leaseError()); + return; + } + + int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(allocator, streamId, n, payload); + + receivers.put(streamId, receiver); + + sendProcessor.onNext(requestStreamFrame); + } + + @Override + void hookOnRemainingRequests(long n) { + if (receiver.isDisposed()) { + return; + } + + sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); + } + + @Override + void hookOnCancel() { + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } else { + if (this.firstRequest) { + payload.release(); + } + } + } + + @Override + void hookOnTerminal(SignalType signalType) { + receivers.remove(streamId); + } + }) + .subscribeOn(serialScheduler, false) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Flux handleChannel(Flux request) { @@ -456,137 +427,135 @@ private Flux handleChannel(Flux request) { private Flux handleChannel(Payload initialPayload, Flux inboundFlux) { final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(); - - return receiver - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - final BaseSubscriber upstreamSubscriber = - new BaseSubscriber() { - - boolean first = true; - - @Override - protected void hookOnSubscribe(Subscription subscription) { - // noops - } - - @Override - protected void hookOnNext(Payload payload) { - if (first) { - // need to skip first since we have already sent it - // no need to release it since it was released earlier on the - // request - // establishment - // phase - first = false; - request(1); - return; - } - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - cancel(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - // no need to send any errors. - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - receiver.onError(t); - return; - } - final ByteBuf frame = - PayloadFrameCodec.encodeNextReleasingPayload( - allocator, streamId, payload); - - sendProcessor.onNext(frame); - } - - @Override - protected void hookOnComplete() { - ByteBuf frame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sendProcessor.onNext(frame); - } - - @Override - protected void hookOnError(Throwable t) { - ByteBuf frame = ErrorFrameCodec.encode(allocator, streamId, t); - sendProcessor.onNext(frame); - receiver.onError(t); - } - - @Override - protected void hookFinally(SignalType type) { - senders.remove(streamId, this); - } - }; - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - initialPayload.release(); - final Throwable t = terminationError; - upstreamSubscriber.cancel(); - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - initialPayload.release(); - receiver.onError(lh.leaseError()); - return; - } - - final int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - final ByteBuf frame = - RequestChannelFrameCodec.encodeReleasingPayload( - allocator, streamId, false, n, initialPayload); - - senders.put(streamId, upstreamSubscriber); - receivers.put(streamId, receiver); - - inboundFlux - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) - .subscribe(upstreamSubscriber); - - sendProcessor.onNext(frame); + final UnicastProcessor receiver = UnicastProcessor.create(Queues.one().get()); + + return Flux.from( + new RequestOperator( + receiver, "RequestStreamFlux allows only a " + "single subscriber") { + + final BaseSubscriber upstreamSubscriber = + new BaseSubscriber() { + + boolean first = true; + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // noops + } + + @Override + protected void hookOnNext(Payload payload) { + if (first) { + // need to skip first since we have already sent it + // no need to release it since it was released earlier on the + // request + // establishment + // phase + first = false; + request(1); + return; } + if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + // no need to send any errors. + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + receiver.onError(t); + return; + } + final ByteBuf frame = + PayloadFrameCodec.encodeNextReleasingPayload( + allocator, streamId, payload); + + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnComplete() { + ByteBuf frame = PayloadFrameCodec.encodeComplete(allocator, streamId); + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnError(Throwable t) { + ByteBuf frame = ErrorFrameCodec.encode(allocator, streamId, t); + sendProcessor.onNext(frame); + receiver.onError(t); + } + + @Override + protected void hookFinally(SignalType type) { + senders.remove(streamId, this); + } + }; + + @Override + void hookOnFirstRequest(long n) { + if (isDisposed()) { + initialPayload.release(); + final Throwable t = terminationError; + upstreamSubscriber.cancel(); + receiver.onError(t); + return; + } - @Override - void hookOnRemainingRequests(long n) { - if (receiver.isDisposed()) { - return; - } + RequesterLeaseHandler lh = leaseHandler; + if (!lh.useLease()) { + initialPayload.release(); + receiver.onError(lh.leaseError()); + return; + } - sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); - } + final int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; - @Override - void hookOnCancel() { - senders.remove(streamId, upstreamSubscriber); - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } - } + final ByteBuf frame = + RequestChannelFrameCodec.encodeReleasingPayload( + allocator, streamId, false, n, initialPayload); - @Override - void hookOnTerminal(SignalType signalType) { - if (signalType == SignalType.ON_ERROR) { - upstreamSubscriber.cancel(); - } - receivers.remove(streamId, receiver); - } + senders.put(streamId, upstreamSubscriber); + receivers.put(streamId, receiver); - @Override - public void cancel() { - upstreamSubscriber.cancel(); - super.cancel(); - } - })) + inboundFlux + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(upstreamSubscriber); + + sendProcessor.onNext(frame); + } + + @Override + void hookOnRemainingRequests(long n) { + if (receiver.isDisposed()) { + return; + } + + sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); + } + + @Override + void hookOnCancel() { + senders.remove(streamId, upstreamSubscriber); + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } + } + + @Override + void hookOnTerminal(SignalType signalType) { + if (signalType == SignalType.ON_ERROR) { + upstreamSubscriber.cancel(); + } + receivers.remove(streamId, receiver); + } + + @Override + public void cancel() { + upstreamSubscriber.cancel(); + super.cancel(); + } + }) .subscribeOn(serialScheduler, false); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 581605ff4..3e2c06e92 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,6 +48,7 @@ import reactor.core.Exceptions; import reactor.core.publisher.*; import reactor.util.annotation.Nullable; +import reactor.util.concurrent.Queues; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ class RSocketResponder implements RSocket { @@ -537,7 +538,7 @@ protected void hookOnError(Throwable throwable) { } private void handleChannel(int streamId, Payload payload, long initialRequestN) { - UnicastProcessor frames = UnicastProcessor.create(); + UnicastProcessor frames = UnicastProcessor.create(Queues.one().get()); channelProcessors.put(streamId, frames); Flux payloads = diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java index 6123b0492..dbca5fef2 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java @@ -1,14 +1,17 @@ package io.rsocket.core; import io.rsocket.Payload; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.CorePublisher; import reactor.core.CoreSubscriber; import reactor.core.Fuseable; import reactor.core.publisher.Operators; import reactor.core.publisher.SignalType; import reactor.util.context.Context; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + /** * This is a support class for handling of request input, intended for use with {@link * Operators#lift}. It ensures serial execution of cancellation vs first request signals and also @@ -16,9 +19,14 @@ * invocations. */ abstract class RequestOperator - implements CoreSubscriber, Fuseable.QueueSubscription { + implements CoreSubscriber, + CorePublisher, + Fuseable.QueueSubscription, + Fuseable { - final CoreSubscriber actual; + final String errorMessageOnSecondSubscription; + + CoreSubscriber actual; Subscription s; Fuseable.QueueSubscription qs; @@ -30,8 +38,25 @@ abstract class RequestOperator static final AtomicIntegerFieldUpdater WIP = AtomicIntegerFieldUpdater.newUpdater(RequestOperator.class, "wip"); - RequestOperator(CoreSubscriber actual) { - this.actual = actual; + RequestOperator(CorePublisher source, String errorMessageOnSecondSubscription) { + this.errorMessageOnSecondSubscription = errorMessageOnSecondSubscription; + source.subscribe(this); + WIP.lazySet(this, -1); + } + + @Override + public void subscribe(Subscriber actual) { + subscribe(Operators.toCoreSubscriber(actual)); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (this.wip == -1 && WIP.compareAndSet(this, -1, 0)) { + this.actual = actual; + actual.onSubscribe(this); + } else { + Operators.error(actual, new IllegalStateException(this.errorMessageOnSecondSubscription)); + } } /** @@ -129,7 +154,6 @@ public void onSubscribe(Subscription s) { if (s instanceof Fuseable.QueueSubscription) { this.qs = (Fuseable.QueueSubscription) s; } - this.actual.onSubscribe(this); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index 1e7bb337f..d78a1d032 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -16,8 +16,6 @@ package io.rsocket.core; -import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; - import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.Payload; @@ -33,10 +31,6 @@ import io.rsocket.test.util.LocalDuplexConnection; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; -import java.time.Duration; -import java.util.List; -import java.util.concurrent.CancellationException; -import java.util.concurrent.atomic.AtomicReference; import org.assertj.core.api.Assertions; import org.junit.Rule; import org.junit.Test; @@ -52,6 +46,13 @@ import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReference; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + public class RSocketTest { @Rule public final SocketRule rule = new SocketRule(); @@ -158,13 +159,13 @@ public Flux requestChannel(Publisher payloads) { } @Test(timeout = 2000) - public void testStream() throws Exception { + public void testStream() { Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); } @Test(timeout = 2000) - public void testChannel() throws Exception { + public void testChannel() { Flux requests = Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); Flux responses = rule.crs.requestChannel(requests);