Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import java.net.InetAddress;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.web.server.Ssl;

/**
* {@link ConfigurationProperties properties} for RSocket support.
*
* @author Brian Clozel
* @author Chris Bono
* @since 2.2.0
*/
@ConfigurationProperties("spring.rsocket")
Expand Down Expand Up @@ -59,6 +62,9 @@ public static class Server {
*/
private String mappingPath;

@NestedConfigurationProperty
private Ssl ssl;

public Integer getPort() {
return this.port;
}
Expand Down Expand Up @@ -91,6 +97,14 @@ public void setMappingPath(String mappingPath) {
this.mappingPath = mappingPath;
}

public Ssl getSsl() {
return this.ssl;
}

public void setSsl(Ssl ssl) {
this.ssl = ssl;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ RSocketServerFactory rSocketServerFactory(RSocketProperties properties, ReactorR
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from(properties.getServer().getAddress()).to(factory::setAddress);
map.from(properties.getServer().getPort()).to(factory::setPort);
map.from(properties.getServer().getSsl()).to(factory::setSsl);
factory.setSocketFactoryProcessors(processors.orderedStream().collect(Collectors.toList()));
return factory;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ void shouldSetLocalServerPortWhenRSocketServerPortIsSet() {
});
}

@Test
void shouldUseSslWhenRocketServerSslIsConfigured() {
reactiveWebContextRunner()
.withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks",
"spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0")
.run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class)
.hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(ServerRSocketFactoryProcessor.class)
.getBean(RSocketServerFactory.class)
.hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks")
.hasFieldOrPropertyWithValue("ssl.keyPassword", "password"));
}

@Test
void shouldUseCustomServerBootstrap() {
contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context)
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.rsocket.server.ServerRSocketFactoryProcessor;
import org.springframework.boot.web.embedded.netty.SslServerCustomizer;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.http.client.reactive.ReactorResourceFactory;
import org.springframework.util.Assert;

Expand All @@ -47,6 +50,7 @@
* by Netty.
*
* @author Brian Clozel
* @author Chris Bono
* @since 2.2.0
*/
public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory {
Expand All @@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur

private Duration lifecycleTimeout;

private Ssl ssl;

private SslStoreProvider sslStoreProvider;

private List<ServerRSocketFactoryProcessor> socketFactoryProcessors = new ArrayList<>();

@Override
Expand All @@ -78,6 +86,16 @@ public void setTransport(RSocketServer.Transport transport) {
this.transport = transport;
}

@Override
public void setSsl(Ssl ssl) {
this.ssl = ssl;
}

@Override
public void setSslStoreProvider(SslStoreProvider sslStoreProvider) {
this.sslStoreProvider = sslStoreProvider;
}

/**
* Set the {@link ReactorResourceFactory} to get the shared resources from.
* @param resourceFactory the server resources
Expand Down Expand Up @@ -136,21 +154,41 @@ private ServerTransport<CloseableChannel> createTransport() {
}

private ServerTransport<CloseableChannel> createWebSocketTransport() {
HttpServer httpServer;
if (this.resourceFactory != null) {
HttpServer httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer
httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer
.runOn(this.resourceFactory.getLoopResources()).addressSupplier(this::getListenAddress));
return WebsocketServerTransport.create(httpServer);
}
return WebsocketServerTransport.create(getListenAddress());
else {
InetSocketAddress listenAddress = this.getListenAddress();
httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
}

if (this.ssl != null && this.ssl.isEnabled()) {
SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider);
httpServer = sslServerCustomizer.apply(httpServer);
}

return WebsocketServerTransport.create(httpServer);
}

private ServerTransport<CloseableChannel> createTcpTransport() {
TcpServer tcpServer;
if (this.resourceFactory != null) {
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.addressSupplier(this::getListenAddress);
return TcpServerTransport.create(tcpServer);
}
return TcpServerTransport.create(getListenAddress());
else {
InetSocketAddress listenAddress = this.getListenAddress();
tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
}

if (this.ssl != null && this.ssl.isEnabled()) {
TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider);
tcpServer = sslServerCustomizer.apply(tcpServer);
}

return TcpServerTransport.create(tcpServer);
}

private InetSocketAddress getListenAddress() {
Expand All @@ -160,4 +198,24 @@ private InetSocketAddress getListenAddress() {
return new InetSocketAddress(this.port);
}

private static final class TcpSslServerCustomizer extends SslServerCustomizer {

private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) {
super(ssl, null, sslStoreProvider);
}

// This does not override the apply in parent - currently just leveraging the
// parent for its "getContextBuilder()" method. This should be refactored when
// we add the concept of http/tcp customizers for RSocket.
private TcpServer apply(TcpServer server) {
try {
return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder()));
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

import java.net.InetAddress;

import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;

/**
* A configurable {@link RSocketServerFactory}.
*
Expand Down Expand Up @@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory {
*/
void setTransport(RSocketServer.Transport transport);

/**
* Sets the SSL configuration that will be applied to the server's default connector.
* @param ssl the SSL configuration
*/
void setSsl(Ssl ssl);

/**
* Sets a provider that will be used to obtain SSL stores.
* @param sslStoreProvider the SSL store provider
*/
void setSslStoreProvider(SslStoreProvider sslStoreProvider);

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,35 @@
package org.springframework.boot.rsocket.netty;

import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.Arrays;

import io.netty.buffer.PooledByteBufAllocator;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.rsocket.AbstractRSocket;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.RSocketFactory;
import io.rsocket.SocketAcceptor;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.client.WebsocketClientTransport;
import io.rsocket.util.DefaultPayload;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
import reactor.core.publisher.Mono;
import reactor.netty.tcp.TcpClient;
import reactor.test.StepVerifier;

import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServer.Transport;
import org.springframework.boot.rsocket.server.ServerRSocketFactoryProcessor;
import org.springframework.boot.web.server.Ssl;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
Expand All @@ -46,6 +55,8 @@
import org.springframework.util.SocketUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.inOrder;
Expand All @@ -56,6 +67,7 @@
*
* @author Brian Clozel
* @author Leo Li
* @author Chris Bono
*/
class NettyRSocketServerFactoryTests {

Expand Down Expand Up @@ -91,7 +103,7 @@ void specificPort() {
factory.setPort(specificPort);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketTcpClient();
this.requester = createRSocketTcpClient(false);
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
Expand All @@ -104,7 +116,7 @@ void websocketTransport() {
factory.setTransport(RSocketServer.Transport.WEBSOCKET);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
this.requester = createRSocketWebSocketClient(false);
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
Expand All @@ -121,7 +133,7 @@ void websocketTransportWithReactorResource() {
factory.setPort(specificPort);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
this.requester = createRSocketWebSocketClient(false);
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
Expand All @@ -145,16 +157,94 @@ void serverProcessors() {
}
}

private RSocketRequester createRSocketTcpClient() {
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().connectTcp(address.getHostString(), address.getPort()).block();
@Test
void tcpTransportBasicSslFromClassPath() {
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP);
}

@Test
void tcpTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP);
}

@Test
void websocketTransportBasicSslFromClassPath() {
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET);
}

@Test
void websocketTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET);
}

private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(transport);
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true)
: createRSocketWebSocketClient(true);
String payload = "test payload";
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(responseMono).expectNext(payload).verifyComplete();
}

private RSocketRequester createRSocketWebSocketClient() {
@Test
void tcpTransportSslRejectsInsecureClient() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(Transport.TCP);
Ssl ssl = new Ssl();
ssl.setKeyStore("classpath:test.jks");
ssl.setKeyPassword("password");
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketTcpClient(false);
String payload = "test payload";
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(responseMono)
.verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class));
}

@Test
void websocketTransportSslRejectsInsecureClient() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(Transport.WEBSOCKET);
Ssl ssl = new Ssl();
ssl.setKeyStore("classpath:test.jks");
ssl.setKeyPassword("password");
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
// For WebSocket, the SSL failure results in a hang on the initial connect call
assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class)
.hasStackTraceContaining("Timeout on blocking read");
}

private RSocketRequester createRSocketTcpClient(boolean ssl) {
TcpClient tcpClient = createTcpClient(ssl);
return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT);
}

private RSocketRequester createRSocketWebSocketClient(boolean ssl) {
TcpClient tcpClient = createTcpClient(ssl);
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT);
}

private TcpClient createTcpClient(boolean ssl) {
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(address)).block();
TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort());
if (ssl) {
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
.trustManager(InsecureTrustManagerFactory.INSTANCE);
tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder));
}
return tcpClient;
}

private RSocketRequester.Builder createRSocketRequesterBuilder() {
Expand Down