Skip to content

Commit 0715750

Browse files
committed
Polish "Add SSL support to RSocketServer"
See gh-19399
1 parent b4810b8 commit 0715750

File tree

4 files changed

+62
-78
lines changed

4 files changed

+62
-78
lines changed

spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2012-2019 the original author or authors.
2+
* Copyright 2012-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -151,41 +151,27 @@ private ServerTransport<CloseableChannel> createTransport() {
151151
}
152152

153153
private ServerTransport<CloseableChannel> createWebSocketTransport() {
154-
HttpServer httpServer;
154+
HttpServer httpServer = HttpServer.create();
155155
if (this.resourceFactory != null) {
156-
httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources())
157-
.bindAddress(this::getListenAddress);
156+
httpServer = httpServer.runOn(this.resourceFactory.getLoopResources());
158157
}
159-
else {
160-
InetSocketAddress listenAddress = this.getListenAddress();
161-
httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
162-
}
163-
164158
if (this.ssl != null && this.ssl.isEnabled()) {
165159
SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider);
166160
httpServer = sslServerCustomizer.apply(httpServer);
167161
}
168-
169-
return WebsocketServerTransport.create(httpServer);
162+
return WebsocketServerTransport.create(httpServer.bindAddress(this::getListenAddress));
170163
}
171164

172165
private ServerTransport<CloseableChannel> createTcpTransport() {
173-
TcpServer tcpServer;
166+
TcpServer tcpServer = TcpServer.create();
174167
if (this.resourceFactory != null) {
175-
tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
176-
.bindAddress(this::getListenAddress);
168+
tcpServer = tcpServer.runOn(this.resourceFactory.getLoopResources());
177169
}
178-
else {
179-
InetSocketAddress listenAddress = this.getListenAddress();
180-
tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
181-
}
182-
183170
if (this.ssl != null && this.ssl.isEnabled()) {
184171
TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider);
185172
tcpServer = sslServerCustomizer.apply(tcpServer);
186173
}
187-
188-
return TcpServerTransport.create(tcpServer);
174+
return TcpServerTransport.create(tcpServer.bindAddress(this::getListenAddress));
189175
}
190176

191177
private InetSocketAddress getListenAddress() {
@@ -201,9 +187,6 @@ private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) {
201187
super(ssl, null, sslStoreProvider);
202188
}
203189

204-
// This does not override the apply in parent - currently just leveraging the
205-
// parent for its "getContextBuilder()" method. This should be refactored when
206-
// we add the concept of http/tcp customizers for RSocket.
207190
private TcpServer apply(TcpServer server) {
208191
try {
209192
return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder()));

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2012-2019 the original author or authors.
2+
* Copyright 2012-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.

spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import java.net.InetSocketAddress;
2020
import java.nio.channels.ClosedChannelException;
21-
import java.time.Duration;
2221
import java.util.Arrays;
2322
import java.util.concurrent.Callable;
2423

@@ -38,12 +37,13 @@
3837
import org.junit.jupiter.api.Test;
3938
import org.mockito.InOrder;
4039
import reactor.core.publisher.Mono;
40+
import reactor.netty.http.client.HttpClient;
4141
import reactor.netty.tcp.TcpClient;
4242
import reactor.test.StepVerifier;
4343

4444
import org.springframework.boot.rsocket.server.RSocketServer;
45-
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
4645
import org.springframework.boot.rsocket.server.RSocketServer.Transport;
46+
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
4747
import org.springframework.boot.web.server.Ssl;
4848
import org.springframework.core.codec.CharSequenceEncoder;
4949
import org.springframework.core.codec.StringDecoder;
@@ -55,7 +55,6 @@
5555

5656
import static org.assertj.core.api.Assertions.assertThat;
5757
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
58-
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5958
import static org.mockito.ArgumentMatchers.any;
6059
import static org.mockito.BDDMockito.will;
6160
import static org.mockito.Mockito.inOrder;
@@ -74,10 +73,11 @@ class NettyRSocketServerFactoryTests {
7473

7574
private RSocketRequester requester;
7675

77-
private static final Duration TIMEOUT = Duration.ofSeconds(3);
78-
7976
@AfterEach
8077
void tearDown() {
78+
if (this.requester != null) {
79+
this.requester.rsocketClient().dispose();
80+
}
8181
if (this.server != null) {
8282
try {
8383
this.server.stop();
@@ -86,9 +86,6 @@ void tearDown() {
8686
// Ignore
8787
}
8888
}
89-
if (this.requester != null) {
90-
this.requester.rsocketClient().dispose();
91-
}
9289
}
9390

9491
private NettyRSocketServerFactory getFactory() {
@@ -105,11 +102,9 @@ void specificPort() {
105102
this.server.start();
106103
return port;
107104
});
108-
this.requester = createRSocketTcpClient(false);
109-
String payload = "test payload";
110-
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
105+
this.requester = createRSocketTcpClient();
111106
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
112-
assertThat(response).isEqualTo(payload);
107+
checkEchoRequest();
113108
}
114109

115110
@Test
@@ -118,10 +113,8 @@ void websocketTransport() {
118113
factory.setTransport(RSocketServer.Transport.WEBSOCKET);
119114
this.server = factory.create(new EchoRequestResponseAcceptor());
120115
this.server.start();
121-
this.requester = createRSocketWebSocketClient(false);
122-
String payload = "test payload";
123-
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
124-
assertThat(response).isEqualTo(payload);
116+
this.requester = createRSocketWebSocketClient();
117+
checkEchoRequest();
125118
}
126119

127120
@Test
@@ -133,10 +126,8 @@ void websocketTransportWithReactorResource() {
133126
factory.setResourceFactory(resourceFactory);
134127
this.server = factory.create(new EchoRequestResponseAcceptor());
135128
this.server.start();
136-
this.requester = createRSocketWebSocketClient(false);
137-
String payload = "test payload";
138-
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
139-
assertThat(response).isEqualTo(payload);
129+
this.requester = createRSocketWebSocketClient();
130+
checkEchoRequest();
140131
}
141132

142133
@Test
@@ -176,6 +167,12 @@ void websocketTransportBasicSslFromFileSystem() {
176167
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET);
177168
}
178169

170+
private void checkEchoRequest() {
171+
String payload = "test payload";
172+
Mono<String> response = this.requester.route("test").data(payload).retrieveMono(String.class);
173+
StepVerifier.create(response).expectNext(payload).verifyComplete();
174+
}
175+
179176
private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) {
180177
NettyRSocketServerFactory factory = getFactory();
181178
factory.setTransport(transport);
@@ -185,11 +182,9 @@ private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Trans
185182
factory.setSsl(ssl);
186183
this.server = factory.create(new EchoRequestResponseAcceptor());
187184
this.server.start();
188-
this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true)
189-
: createRSocketWebSocketClient(true);
190-
String payload = "test payload";
191-
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
192-
StepVerifier.create(responseMono).expectNext(payload).verifyComplete();
185+
this.requester = (transport == Transport.TCP) ? createSecureRSocketTcpClient()
186+
: createSecureRSocketWebSocketClient();
187+
checkEchoRequest();
193188
}
194189

195190
@Test
@@ -202,48 +197,54 @@ void tcpTransportSslRejectsInsecureClient() {
202197
factory.setSsl(ssl);
203198
this.server = factory.create(new EchoRequestResponseAcceptor());
204199
this.server.start();
205-
this.requester = createRSocketTcpClient(false);
200+
this.requester = createRSocketTcpClient();
206201
String payload = "test payload";
207202
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
208203
StepVerifier.create(responseMono)
209204
.verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class));
210205
}
211206

212-
@Test
213-
void websocketTransportSslRejectsInsecureClient() {
214-
NettyRSocketServerFactory factory = getFactory();
215-
factory.setTransport(Transport.WEBSOCKET);
216-
Ssl ssl = new Ssl();
217-
ssl.setKeyStore("classpath:test.jks");
218-
ssl.setKeyPassword("password");
219-
factory.setSsl(ssl);
220-
this.server = factory.create(new EchoRequestResponseAcceptor());
221-
this.server.start();
222-
// For WebSocket, the SSL failure results in a hang on the initial connect call
223-
assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class)
224-
.hasStackTraceContaining("Timeout on blocking read");
207+
private RSocketRequester createRSocketTcpClient() {
208+
return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createTcpClient()));
209+
}
210+
211+
private RSocketRequester createRSocketWebSocketClient() {
212+
return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(createHttpClient(), "/"));
225213
}
226214

227-
private RSocketRequester createRSocketTcpClient(boolean ssl) {
228-
TcpClient tcpClient = createTcpClient(ssl);
229-
return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT);
215+
private RSocketRequester createSecureRSocketTcpClient() {
216+
return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createSecureTcpClient()));
230217
}
231218

232-
private RSocketRequester createRSocketWebSocketClient(boolean ssl) {
233-
TcpClient tcpClient = createTcpClient(ssl);
234-
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT);
219+
private RSocketRequester createSecureRSocketWebSocketClient() {
220+
return createRSocketRequesterBuilder()
221+
.transport(WebsocketClientTransport.create(createSecureHttpClient(), "/"));
235222
}
236223

237-
private TcpClient createTcpClient(boolean ssl) {
224+
private HttpClient createSecureHttpClient() {
225+
HttpClient httpClient = createHttpClient();
226+
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
227+
.trustManager(InsecureTrustManagerFactory.INSTANCE);
228+
return httpClient.secure((spec) -> spec.sslContext(builder));
229+
}
230+
231+
private HttpClient createHttpClient() {
238232
Assertions.assertThat(this.server).isNotNull();
239233
InetSocketAddress address = this.server.address();
240-
TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort());
241-
if (ssl) {
242-
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
243-
.trustManager(InsecureTrustManagerFactory.INSTANCE);
244-
tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder));
245-
}
246-
return tcpClient;
234+
return HttpClient.create().host(address.getHostName()).port(address.getPort());
235+
}
236+
237+
private TcpClient createSecureTcpClient() {
238+
TcpClient tcpClient = createTcpClient();
239+
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
240+
.trustManager(InsecureTrustManagerFactory.INSTANCE);
241+
return tcpClient.secure((spec) -> spec.sslContext(builder));
242+
}
243+
244+
private TcpClient createTcpClient() {
245+
Assertions.assertThat(this.server).isNotNull();
246+
InetSocketAddress address = this.server.address();
247+
return TcpClient.create().host(address.getHostName()).port(address.getPort());
247248
}
248249

249250
private RSocketRequester.Builder createRSocketRequesterBuilder() {

0 commit comments

Comments
 (0)