diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java index 232eb1db4272..cf18235260e4 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java @@ -19,12 +19,15 @@ import java.net.InetAddress; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.boot.rsocket.server.RSocketServer; +import org.springframework.boot.web.server.Ssl; /** * {@link ConfigurationProperties properties} for RSocket support. * * @author Brian Clozel + * @author Chris Bono * @since 2.2.0 */ @ConfigurationProperties("spring.rsocket") @@ -59,6 +62,9 @@ public static class Server { */ private String mappingPath; + @NestedConfigurationProperty + private Ssl ssl; + public Integer getPort() { return this.port; } @@ -91,6 +97,14 @@ public void setMappingPath(String mappingPath) { this.mappingPath = mappingPath; } + public Ssl getSsl() { + return this.ssl; + } + + public void setSsl(Ssl ssl) { + this.ssl = ssl; + } + } } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java index 8856f145213f..82d238319866 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java @@ -97,6 +97,7 @@ RSocketServerFactory rSocketServerFactory(RSocketProperties properties, ReactorR PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); map.from(properties.getServer().getAddress()).to(factory::setAddress); map.from(properties.getServer().getPort()).to(factory::setPort); + map.from(properties.getServer().getSsl()).to(factory::setSsl); factory.setSocketFactoryProcessors(processors.orderedStream().collect(Collectors.toList())); return factory; } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java index f03c09c9fa96..1a8367756d6b 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java @@ -93,6 +93,18 @@ void shouldSetLocalServerPortWhenRSocketServerPortIsSet() { }); } + @Test + void shouldUseSslWhenRocketServerSslIsConfigured() { + reactiveWebContextRunner() + .withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks", + "spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0") + .run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class) + .hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(ServerRSocketFactoryProcessor.class) + .getBean(RSocketServerFactory.class) + .hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks") + .hasFieldOrPropertyWithValue("ssl.keyPassword", "password")); + } + @Test void shouldUseCustomServerBootstrap() { contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context) diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks b/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks new file mode 100644 index 000000000000..0fc3e802f754 Binary files /dev/null and b/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks differ diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java index 9a2bbae9c8d2..8cab22c5d599 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java @@ -39,6 +39,9 @@ import org.springframework.boot.rsocket.server.RSocketServer; import org.springframework.boot.rsocket.server.RSocketServerFactory; import org.springframework.boot.rsocket.server.ServerRSocketFactoryProcessor; +import org.springframework.boot.web.embedded.netty.SslServerCustomizer; +import org.springframework.boot.web.server.Ssl; +import org.springframework.boot.web.server.SslStoreProvider; import org.springframework.http.client.reactive.ReactorResourceFactory; import org.springframework.util.Assert; @@ -47,6 +50,7 @@ * by Netty. * * @author Brian Clozel + * @author Chris Bono * @since 2.2.0 */ public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory { @@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur private Duration lifecycleTimeout; + private Ssl ssl; + + private SslStoreProvider sslStoreProvider; + private List socketFactoryProcessors = new ArrayList<>(); @Override @@ -78,6 +86,16 @@ public void setTransport(RSocketServer.Transport transport) { this.transport = transport; } + @Override + public void setSsl(Ssl ssl) { + this.ssl = ssl; + } + + @Override + public void setSslStoreProvider(SslStoreProvider sslStoreProvider) { + this.sslStoreProvider = sslStoreProvider; + } + /** * Set the {@link ReactorResourceFactory} to get the shared resources from. * @param resourceFactory the server resources @@ -136,21 +154,41 @@ private ServerTransport createTransport() { } private ServerTransport createWebSocketTransport() { + HttpServer httpServer; if (this.resourceFactory != null) { - HttpServer httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer + httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer .runOn(this.resourceFactory.getLoopResources()).addressSupplier(this::getListenAddress)); - return WebsocketServerTransport.create(httpServer); } - return WebsocketServerTransport.create(getListenAddress()); + else { + InetSocketAddress listenAddress = this.getListenAddress(); + httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort()); + } + + if (this.ssl != null && this.ssl.isEnabled()) { + SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider); + httpServer = sslServerCustomizer.apply(httpServer); + } + + return WebsocketServerTransport.create(httpServer); } private ServerTransport createTcpTransport() { + TcpServer tcpServer; if (this.resourceFactory != null) { - TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources()) + tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources()) .addressSupplier(this::getListenAddress); - return TcpServerTransport.create(tcpServer); } - return TcpServerTransport.create(getListenAddress()); + else { + InetSocketAddress listenAddress = this.getListenAddress(); + tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort()); + } + + if (this.ssl != null && this.ssl.isEnabled()) { + TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider); + tcpServer = sslServerCustomizer.apply(tcpServer); + } + + return TcpServerTransport.create(tcpServer); } private InetSocketAddress getListenAddress() { @@ -160,4 +198,24 @@ private InetSocketAddress getListenAddress() { return new InetSocketAddress(this.port); } + private static final class TcpSslServerCustomizer extends SslServerCustomizer { + + private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) { + super(ssl, null, sslStoreProvider); + } + + // This does not override the apply in parent - currently just leveraging the + // parent for its "getContextBuilder()" method. This should be refactored when + // we add the concept of http/tcp customizers for RSocket. + private TcpServer apply(TcpServer server) { + try { + return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder())); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + + } + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java index cb9741054293..afbf549ba2dd 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java @@ -18,6 +18,9 @@ import java.net.InetAddress; +import org.springframework.boot.web.server.Ssl; +import org.springframework.boot.web.server.SslStoreProvider; + /** * A configurable {@link RSocketServerFactory}. * @@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory { */ void setTransport(RSocketServer.Transport transport); + /** + * Sets the SSL configuration that will be applied to the server's default connector. + * @param ssl the SSL configuration + */ + void setSsl(Ssl ssl); + + /** + * Sets a provider that will be used to obtain SSL stores. + * @param sslStoreProvider the SSL store provider + */ + void setSslStoreProvider(SslStoreProvider sslStoreProvider); + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java index a8055999f467..43e72b75290e 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java @@ -17,16 +17,21 @@ package org.springframework.boot.rsocket.netty; import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Arrays; import io.netty.buffer.PooledByteBufAllocator; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.rsocket.AbstractRSocket; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.util.DefaultPayload; import org.assertj.core.api.Assertions; @@ -34,9 +39,13 @@ import org.junit.jupiter.api.Test; import org.mockito.InOrder; import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.test.StepVerifier; import org.springframework.boot.rsocket.server.RSocketServer; +import org.springframework.boot.rsocket.server.RSocketServer.Transport; import org.springframework.boot.rsocket.server.ServerRSocketFactoryProcessor; +import org.springframework.boot.web.server.Ssl; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.NettyDataBufferFactory; @@ -46,6 +55,8 @@ import org.springframework.util.SocketUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.inOrder; @@ -56,6 +67,7 @@ * * @author Brian Clozel * @author Leo Li + * @author Chris Bono */ class NettyRSocketServerFactoryTests { @@ -91,7 +103,7 @@ void specificPort() { factory.setPort(specificPort); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketTcpClient(); + this.requester = createRSocketTcpClient(false); String payload = "test payload"; String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); assertThat(this.server.address().getPort()).isEqualTo(specificPort); @@ -104,7 +116,7 @@ void websocketTransport() { factory.setTransport(RSocketServer.Transport.WEBSOCKET); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketWebSocketClient(); + this.requester = createRSocketWebSocketClient(false); String payload = "test payload"; String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); assertThat(response).isEqualTo(payload); @@ -121,7 +133,7 @@ void websocketTransportWithReactorResource() { factory.setPort(specificPort); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketWebSocketClient(); + this.requester = createRSocketWebSocketClient(false); String payload = "test payload"; String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); assertThat(response).isEqualTo(payload); @@ -145,16 +157,94 @@ void serverProcessors() { } } - private RSocketRequester createRSocketTcpClient() { - Assertions.assertThat(this.server).isNotNull(); - InetSocketAddress address = this.server.address(); - return createRSocketRequesterBuilder().connectTcp(address.getHostString(), address.getPort()).block(); + @Test + void tcpTransportBasicSslFromClassPath() { + testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP); + } + + @Test + void tcpTransportBasicSslFromFileSystem() { + testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP); + } + + @Test + void websocketTransportBasicSslFromClassPath() { + testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET); + } + + @Test + void websocketTransportBasicSslFromFileSystem() { + testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET); + } + + private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(transport); + Ssl ssl = new Ssl(); + ssl.setKeyStore(keyStore); + ssl.setKeyPassword(keyPassword); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true) + : createRSocketWebSocketClient(true); + String payload = "test payload"; + Mono responseMono = this.requester.route("test").data(payload).retrieveMono(String.class); + StepVerifier.create(responseMono).expectNext(payload).verifyComplete(); } - private RSocketRequester createRSocketWebSocketClient() { + @Test + void tcpTransportSslRejectsInsecureClient() { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(Transport.TCP); + Ssl ssl = new Ssl(); + ssl.setKeyStore("classpath:test.jks"); + ssl.setKeyPassword("password"); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + this.requester = createRSocketTcpClient(false); + String payload = "test payload"; + Mono responseMono = this.requester.route("test").data(payload).retrieveMono(String.class); + StepVerifier.create(responseMono) + .verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class)); + } + + @Test + void websocketTransportSslRejectsInsecureClient() { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(Transport.WEBSOCKET); + Ssl ssl = new Ssl(); + ssl.setKeyStore("classpath:test.jks"); + ssl.setKeyPassword("password"); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + // For WebSocket, the SSL failure results in a hang on the initial connect call + assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class) + .hasStackTraceContaining("Timeout on blocking read"); + } + + private RSocketRequester createRSocketTcpClient(boolean ssl) { + TcpClient tcpClient = createTcpClient(ssl); + return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT); + } + + private RSocketRequester createRSocketWebSocketClient(boolean ssl) { + TcpClient tcpClient = createTcpClient(ssl); + return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT); + } + + private TcpClient createTcpClient(boolean ssl) { Assertions.assertThat(this.server).isNotNull(); InetSocketAddress address = this.server.address(); - return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(address)).block(); + TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort()); + if (ssl) { + SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE); + tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder)); + } + return tcpClient; } private RSocketRequester.Builder createRSocketRequesterBuilder() {