diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java index ce43cd1fd..b7cad7042 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultClientRSocketFactory.java @@ -331,6 +331,7 @@ public Mono start() { payloadDecoder, errorConsumer, StreamIdSupplier.clientSupplier(), + mtu, keepAliveTickPeriod(), keepAliveTimeout(), keepAliveHandler, @@ -379,7 +380,8 @@ public Mono start() { wrappedRSocketHandler, payloadDecoder, errorConsumer, - responderLeaseHandler); + responderLeaseHandler, + mtu); return wrappedConnection .sendOne(setupFrame) diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java index f2acb9af0..85543181a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultServerRSocketFactory.java @@ -281,6 +281,7 @@ private Mono acceptSetup( payloadDecoder, errorConsumer, StreamIdSupplier.serverSupplier(), + mtu, setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, @@ -317,7 +318,8 @@ private Mono acceptSetup( wrappedRSocketHandler, payloadDecoder, errorConsumer, - responderLeaseHandler); + responderLeaseHandler, + mtu); }) .doFinally(signalType -> setupPayload.release()) .then(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java new file mode 100644 index 000000000..3b6b375d1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -0,0 +1,32 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; + +final class PayloadValidationUtils { + static final String INVALID_PAYLOAD_ERROR_MESSAGE = + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."; + + static boolean isValid(int mtu, Payload payload) { + if (mtu > 0) { + return true; + } + + if (payload.hasMetadata()) { + return (((FrameHeaderFlyweight.size() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE + + FrameHeaderFlyweight.size() + + payload.data().readableBytes() + + payload.metadata().readableBytes()) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } else { + return (((FrameHeaderFlyweight.size() + + payload.data().readableBytes() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 6c26361a2..fc3175b15 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -16,6 +16,7 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; @@ -88,6 +89,7 @@ class RSocketRequester implements RSocket { private final IntObjectMap senders; private final IntObjectMap> receivers; private final UnboundedProcessor sendProcessor; + private final int mtu; private final RequesterLeaseHandler leaseHandler; private final ByteBufAllocator allocator; private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; @@ -99,6 +101,7 @@ class RSocketRequester implements RSocket { PayloadDecoder payloadDecoder, Consumer errorConsumer, StreamIdSupplier streamIdSupplier, + int mtu, int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, @@ -108,6 +111,7 @@ class RSocketRequester implements RSocket { this.payloadDecoder = payloadDecoder; this.errorConsumer = errorConsumer; this.streamIdSupplier = streamIdSupplier; + this.mtu = mtu; this.leaseHandler = leaseHandler; this.senders = new SynchronizedIntObjectHashMap<>(); this.receivers = new SynchronizedIntObjectHashMap<>(); @@ -186,6 +190,11 @@ private Mono handleFireAndForget(Payload payload) { return Mono.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + final int streamId = streamIdSupplier.nextStreamId(receivers); return UnicastMonoEmpty.newInstance( @@ -210,6 +219,11 @@ private Mono handleRequestResponse(final Payload payload) { return Mono.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + int streamId = streamIdSupplier.nextStreamId(receivers); final UnboundedProcessor sendProcessor = this.sendProcessor; @@ -255,6 +269,11 @@ private Flux handleRequestStream(final Payload payload) { return Flux.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + int streamId = streamIdSupplier.nextStreamId(receivers); final UnboundedProcessor sendProcessor = this.sendProcessor; @@ -317,6 +336,13 @@ private Flux handleChannel(Flux request) { (s, flux) -> { Payload payload = s.get(); if (payload != null) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + errorConsumer.accept(t); + return Mono.error(t); + } return handleChannel(payload, flux); } else { return flux; @@ -348,6 +374,17 @@ protected void hookOnNext(Payload payload) { first = false; return; } + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + errorConsumer.accept(t); + // no need to send any errors. + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + receiver.onError(t); + return; + } final ByteBuf frame = PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, payload); @@ -434,6 +471,11 @@ private Mono handleMetadataPush(Payload payload) { return Mono.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + return UnicastMonoEmpty.newInstance( () -> { ByteBuf metadataPushFrame = @@ -444,6 +486,7 @@ private Mono handleMetadataPush(Payload payload) { }); } + @Nullable private Throwable checkAvailable() { Throwable err = this.terminationError; if (err != null) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index de6e8ad23..6f235587a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -16,6 +16,8 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.ReferenceCountUtil; @@ -51,6 +53,8 @@ class RSocketResponder implements ResponderRSocket { private final Consumer errorConsumer; private final ResponderLeaseHandler leaseHandler; + private final int mtu; + private final IntObjectMap sendingSubscriptions; private final IntObjectMap> channelProcessors; @@ -63,9 +67,11 @@ class RSocketResponder implements ResponderRSocket { RSocket requestHandler, PayloadDecoder payloadDecoder, Consumer errorConsumer, - ResponderLeaseHandler leaseHandler) { + ResponderLeaseHandler leaseHandler, + int mtu) { this.allocator = allocator; this.connection = connection; + this.mtu = mtu; this.requestHandler = requestHandler; this.responderRSocket = @@ -371,6 +377,15 @@ protected void hookOnNext(Payload payload) { isEmpty = false; } + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; + } + ByteBuf byteBuf; try { byteBuf = PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload); @@ -417,6 +432,15 @@ protected void hookOnSubscribe(Subscription s) { @Override protected void hookOnNext(Payload payload) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; + } + ByteBuf byteBuf; try { byteBuf = PayloadFrameFlyweight.encodeNext(allocator, streamId, payload); diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index d634f7374..e59ece86f 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -20,7 +20,14 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCountUtil; -import io.rsocket.frame.*; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; import java.util.function.Consumer; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index 6cb05dec1..10725238a 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -61,6 +61,7 @@ static RSocketState requester(int tickPeriod, int timeout) { DefaultPayload::create, errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new DefaultKeepAliveHandler(connection), @@ -86,6 +87,7 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { DefaultPayload::create, errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new ResumableKeepAliveHandler(resumableConnection), diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java new file mode 100644 index 000000000..e91fce848 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -0,0 +1,99 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class PayloadValidationUtilsTest { + + @Test + void shouldBeValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size() + + 1]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation0() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2]; + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK / 2 + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size() + - FrameHeaderFlyweight.size()]; + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation1() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation2() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation3() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation4() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 3cbb3c5d7..0a7f7a196 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -94,6 +94,7 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, requesterLeaseHandler); @@ -111,7 +112,8 @@ void setUp() { mockRSocketHandler, payloadDecoder, err -> {}, - responderLeaseHandler); + responderLeaseHandler, + 0); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index 8a2e114cc..8380290f2 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -67,6 +67,7 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 20b1825fa..101500da7 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -16,10 +16,21 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static io.rsocket.frame.FrameType.*; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.KEEPALIVE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; @@ -27,9 +38,18 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.util.CharsetUtil; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.*; +import io.rsocket.frame.CancelFrameFlyweight; +import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestNFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.DefaultPayload; @@ -39,7 +59,10 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.assertj.core.api.Assertions; import org.junit.Rule; import org.junit.Test; @@ -51,6 +74,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; +import reactor.test.StepVerifier; public class RSocketRequesterTest { @@ -262,6 +286,62 @@ protected void hookOnSubscribe(Subscription subscription) {} Assertions.assertThat(iterator.hasNext()).isFalse(); } + @Test + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + generator.apply(rule.socket, DefaultPayload.create(data, metadata))) + .expectSubscription() + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + }); + } + + @Test + public void + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + rule.socket.requestChannel( + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata)))) + .expectSubscription() + .then( + () -> + rule.connection.addToReceivedBuffer( + RequestNFrameFlyweight.encode( + ByteBufAllocator.DEFAULT, + rule.getStreamIdForRequestType(REQUEST_CHANNEL), + 2))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + } + + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + public int sendRequestResponse(Publisher response) { Subscriber sub = TestSubscriber.create(); response.subscribe(sub); @@ -285,6 +365,7 @@ protected RSocketRequester newRSocket() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 10157532a..5c147f46f 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -16,6 +16,7 @@ package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.*; @@ -34,11 +35,15 @@ import io.rsocket.util.EmptyPayload; import java.util.Collection; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; +import org.assertj.core.api.Assertions; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class RSocketResponderTest { @@ -110,6 +115,58 @@ public Mono requestResponse(Payload payload) { assertThat("Subscription not cancelled.", cancelled.get(), is(true)); } + @Test + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + final AbstractRSocket acceptingSocket = + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload p) { + return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestStream(Payload p) { + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + }; + rule.setAcceptingSocket(acceptingSocket); + + final Runnable[] runnables = { + () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + }; + + for (Runnable runnable : runnables) { + runnable.run(); + Assertions.assertThat(rule.errors) + .first() + .isInstanceOf(IllegalArgumentException.class) + .hasToString("java.lang.IllegalArgumentException: " + INVALID_PAYLOAD_ERROR_MESSAGE); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR) + .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)); + + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + rule.init(); + rule.setAcceptingSocket(acceptingSocket); + } + } + public static class ServerSocketRule extends AbstractSocketRule { private RSocket acceptingSocket; @@ -151,7 +208,8 @@ protected RSocketResponder newRSocket() { acceptingSocket, DefaultPayload::create, throwable -> errors.add(throwable), - ResponderLeaseHandler.None); + ResponderLeaseHandler.None, + 0); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index b18fad890..edcc8971f 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -222,7 +222,8 @@ public Flux requestChannel(Publisher payloads) { requestAcceptor, DefaultPayload::create, throwable -> serverErrors.add(throwable), - ResponderLeaseHandler.None); + ResponderLeaseHandler.None, + 0); crs = new RSocketRequester( @@ -233,6 +234,7 @@ public Flux requestChannel(Publisher payloads) { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index daab5d246..9344d69da 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -58,6 +58,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); @@ -93,6 +94,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None);