Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ public Mono<RSocket> start() {
payloadDecoder,
errorConsumer,
StreamIdSupplier.clientSupplier(),
mtu,
keepAliveTickPeriod(),
keepAliveTimeout(),
keepAliveHandler,
Expand Down Expand Up @@ -379,7 +380,8 @@ public Mono<RSocket> start() {
wrappedRSocketHandler,
payloadDecoder,
errorConsumer,
responderLeaseHandler);
responderLeaseHandler,
mtu);

return wrappedConnection
.sendOne(setupFrame)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ private Mono<Void> acceptSetup(
payloadDecoder,
errorConsumer,
StreamIdSupplier.serverSupplier(),
mtu,
setupPayload.keepAliveInterval(),
setupPayload.keepAliveMaxLifetime(),
keepAliveHandler,
Expand Down Expand Up @@ -317,7 +318,8 @@ private Mono<Void> acceptSetup(
wrappedRSocketHandler,
payloadDecoder,
errorConsumer,
responderLeaseHandler);
responderLeaseHandler,
mtu);
})
.doFinally(signalType -> setupPayload.release())
.then();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.rsocket.core;

import io.rsocket.Payload;
import io.rsocket.frame.FrameHeaderFlyweight;
import io.rsocket.frame.FrameLengthFlyweight;

final class PayloadValidationUtils {
static final String INVALID_PAYLOAD_ERROR_MESSAGE =
"The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory.";

static boolean isValid(int mtu, Payload payload) {
if (mtu > 0) {
return true;
}

if (payload.hasMetadata()) {
return (((FrameHeaderFlyweight.size()
+ FrameLengthFlyweight.FRAME_LENGTH_SIZE
+ FrameHeaderFlyweight.size()
+ payload.data().readableBytes()
+ payload.metadata().readableBytes())
& ~FrameLengthFlyweight.FRAME_LENGTH_MASK)
== 0);
} else {
return (((FrameHeaderFlyweight.size()
+ payload.data().readableBytes()
+ FrameLengthFlyweight.FRAME_LENGTH_SIZE)
& ~FrameLengthFlyweight.FRAME_LENGTH_MASK)
== 0);
}
}
}
43 changes: 43 additions & 0 deletions rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.rsocket.core;

import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE;
import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport;
import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive;

Expand Down Expand Up @@ -88,6 +89,7 @@ class RSocketRequester implements RSocket {
private final IntObjectMap<Subscription> senders;
private final IntObjectMap<Processor<Payload, Payload>> receivers;
private final UnboundedProcessor<ByteBuf> sendProcessor;
private final int mtu;
private final RequesterLeaseHandler leaseHandler;
private final ByteBufAllocator allocator;
private final KeepAliveFramesAcceptor keepAliveFramesAcceptor;
Expand All @@ -99,6 +101,7 @@ class RSocketRequester implements RSocket {
PayloadDecoder payloadDecoder,
Consumer<Throwable> errorConsumer,
StreamIdSupplier streamIdSupplier,
int mtu,
int keepAliveTickPeriod,
int keepAliveAckTimeout,
@Nullable KeepAliveHandler keepAliveHandler,
Expand All @@ -108,6 +111,7 @@ class RSocketRequester implements RSocket {
this.payloadDecoder = payloadDecoder;
this.errorConsumer = errorConsumer;
this.streamIdSupplier = streamIdSupplier;
this.mtu = mtu;
this.leaseHandler = leaseHandler;
this.senders = new SynchronizedIntObjectHashMap<>();
this.receivers = new SynchronizedIntObjectHashMap<>();
Expand Down Expand Up @@ -186,6 +190,11 @@ private Mono<Void> handleFireAndForget(Payload payload) {
return Mono.error(err);
}

if (!PayloadValidationUtils.isValid(this.mtu, payload)) {
payload.release();
return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE));
}

final int streamId = streamIdSupplier.nextStreamId(receivers);

return UnicastMonoEmpty.newInstance(
Expand All @@ -210,6 +219,11 @@ private Mono<Payload> handleRequestResponse(final Payload payload) {
return Mono.error(err);
}

if (!PayloadValidationUtils.isValid(this.mtu, payload)) {
payload.release();
return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE));
}

int streamId = streamIdSupplier.nextStreamId(receivers);
final UnboundedProcessor<ByteBuf> sendProcessor = this.sendProcessor;

Expand Down Expand Up @@ -255,6 +269,11 @@ private Flux<Payload> handleRequestStream(final Payload payload) {
return Flux.error(err);
}

if (!PayloadValidationUtils.isValid(this.mtu, payload)) {
payload.release();
return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE));
}

int streamId = streamIdSupplier.nextStreamId(receivers);

final UnboundedProcessor<ByteBuf> sendProcessor = this.sendProcessor;
Expand Down Expand Up @@ -317,6 +336,13 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
(s, flux) -> {
Payload payload = s.get();
if (payload != null) {
if (!PayloadValidationUtils.isValid(mtu, payload)) {
payload.release();
final IllegalArgumentException t =
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
errorConsumer.accept(t);
return Mono.error(t);
}
return handleChannel(payload, flux);
} else {
return flux;
Expand Down Expand Up @@ -348,6 +374,17 @@ protected void hookOnNext(Payload payload) {
first = false;
return;
}
if (!PayloadValidationUtils.isValid(mtu, payload)) {
payload.release();
cancel();
final IllegalArgumentException t =
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
errorConsumer.accept(t);
// no need to send any errors.
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
receiver.onError(t);
return;
}
final ByteBuf frame =
PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, payload);

Expand Down Expand Up @@ -434,6 +471,11 @@ private Mono<Void> handleMetadataPush(Payload payload) {
return Mono.error(err);
}

if (!PayloadValidationUtils.isValid(this.mtu, payload)) {
payload.release();
return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE));
}

return UnicastMonoEmpty.newInstance(
() -> {
ByteBuf metadataPushFrame =
Expand All @@ -444,6 +486,7 @@ private Mono<Void> handleMetadataPush(Payload payload) {
});
}

@Nullable
private Throwable checkAvailable() {
Throwable err = this.terminationError;
if (err != null) {
Expand Down
26 changes: 25 additions & 1 deletion rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package io.rsocket.core;

import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.ReferenceCountUtil;
Expand Down Expand Up @@ -51,6 +53,8 @@ class RSocketResponder implements ResponderRSocket {
private final Consumer<Throwable> errorConsumer;
private final ResponderLeaseHandler leaseHandler;

private final int mtu;

private final IntObjectMap<Subscription> sendingSubscriptions;
private final IntObjectMap<Processor<Payload, Payload>> channelProcessors;

Expand All @@ -63,9 +67,11 @@ class RSocketResponder implements ResponderRSocket {
RSocket requestHandler,
PayloadDecoder payloadDecoder,
Consumer<Throwable> errorConsumer,
ResponderLeaseHandler leaseHandler) {
ResponderLeaseHandler leaseHandler,
int mtu) {
this.allocator = allocator;
this.connection = connection;
this.mtu = mtu;

this.requestHandler = requestHandler;
this.responderRSocket =
Expand Down Expand Up @@ -371,6 +377,15 @@ protected void hookOnNext(Payload payload) {
isEmpty = false;
}

if (!PayloadValidationUtils.isValid(mtu, payload)) {
payload.release();
cancel();
final IllegalArgumentException t =
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
handleError(streamId, t);
return;
}

ByteBuf byteBuf;
try {
byteBuf = PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload);
Expand Down Expand Up @@ -417,6 +432,15 @@ protected void hookOnSubscribe(Subscription s) {

@Override
protected void hookOnNext(Payload payload) {
if (!PayloadValidationUtils.isValid(mtu, payload)) {
payload.release();
cancel();
final IllegalArgumentException t =
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
handleError(streamId, t);
return;
}

ByteBuf byteBuf;
try {
byteBuf = PayloadFrameFlyweight.encodeNext(allocator, streamId, payload);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCountUtil;
import io.rsocket.frame.*;
import io.rsocket.frame.FrameHeaderFlyweight;
import io.rsocket.frame.FrameLengthFlyweight;
import io.rsocket.frame.FrameType;
import io.rsocket.frame.PayloadFrameFlyweight;
import io.rsocket.frame.RequestChannelFrameFlyweight;
import io.rsocket.frame.RequestFireAndForgetFrameFlyweight;
import io.rsocket.frame.RequestResponseFrameFlyweight;
import io.rsocket.frame.RequestStreamFrameFlyweight;
import java.util.function.Consumer;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
Expand Down
2 changes: 2 additions & 0 deletions rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ static RSocketState requester(int tickPeriod, int timeout) {
DefaultPayload::create,
errors,
StreamIdSupplier.clientSupplier(),
0,
tickPeriod,
timeout,
new DefaultKeepAliveHandler(connection),
Expand All @@ -86,6 +87,7 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) {
DefaultPayload::create,
errors,
StreamIdSupplier.clientSupplier(),
0,
tickPeriod,
timeout,
new ResumableKeepAliveHandler(resumableConnection),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package io.rsocket.core;

import io.rsocket.Payload;
import io.rsocket.frame.FrameHeaderFlyweight;
import io.rsocket.frame.FrameLengthFlyweight;
import io.rsocket.util.DefaultPayload;
import java.util.concurrent.ThreadLocalRandom;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

class PayloadValidationUtilsTest {

@Test
void shouldBeValidFrameWithNoFragmentation() {
byte[] data =
new byte
[FrameLengthFlyweight.FRAME_LENGTH_MASK
- FrameLengthFlyweight.FRAME_LENGTH_SIZE
- FrameHeaderFlyweight.size()];
ThreadLocalRandom.current().nextBytes(data);
final Payload payload = DefaultPayload.create(data);

Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue();
}

@Test
void shouldBeInValidFrameWithNoFragmentation() {
byte[] data =
new byte
[FrameLengthFlyweight.FRAME_LENGTH_MASK
- FrameLengthFlyweight.FRAME_LENGTH_SIZE
- FrameHeaderFlyweight.size()
+ 1];
ThreadLocalRandom.current().nextBytes(data);
final Payload payload = DefaultPayload.create(data);

Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse();
}

@Test
void shouldBeValidFrameWithNoFragmentation0() {
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2];
byte[] data =
new byte
[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2
- FrameLengthFlyweight.FRAME_LENGTH_SIZE
- FrameHeaderFlyweight.size()
- FrameHeaderFlyweight.size()];
ThreadLocalRandom.current().nextBytes(data);
ThreadLocalRandom.current().nextBytes(metadata);
final Payload payload = DefaultPayload.create(data, metadata);

Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue();
}

@Test
void shouldBeInValidFrameWithNoFragmentation1() {
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
ThreadLocalRandom.current().nextBytes(metadata);
ThreadLocalRandom.current().nextBytes(data);
final Payload payload = DefaultPayload.create(data, metadata);

Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse();
}

@Test
void shouldBeValidFrameWithNoFragmentation2() {
byte[] metadata = new byte[1];
byte[] data = new byte[1];
ThreadLocalRandom.current().nextBytes(metadata);
ThreadLocalRandom.current().nextBytes(data);
final Payload payload = DefaultPayload.create(data, metadata);

Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue();
}

@Test
void shouldBeValidFrameWithNoFragmentation3() {
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
ThreadLocalRandom.current().nextBytes(metadata);
ThreadLocalRandom.current().nextBytes(data);
final Payload payload = DefaultPayload.create(data, metadata);

Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue();
}

@Test
void shouldBeValidFrameWithNoFragmentation4() {
byte[] metadata = new byte[1];
byte[] data = new byte[1];
ThreadLocalRandom.current().nextBytes(metadata);
ThreadLocalRandom.current().nextBytes(data);
final Payload payload = DefaultPayload.create(data, metadata);

Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue();
}
}
Loading