Skip to content

Commit 4dd5c27

Browse files
committed
Add missing HandshakeInterceptor for STOMP endpoints
Issue: SPR-11845
1 parent 6d6cc0e commit 4dd5c27

File tree

3 files changed

+48
-11
lines changed

3 files changed

+48
-11
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.web.socket.config.annotation;
1818

1919
import org.springframework.web.socket.server.HandshakeHandler;
20+
import org.springframework.web.socket.server.HandshakeInterceptor;
2021

2122
/**
2223
* A contract for configuring a STOMP over WebSocket endpoint.
@@ -36,4 +37,9 @@ public interface StompWebSocketEndpointRegistration {
3637
*/
3738
StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler);
3839

40+
/**
41+
* Configure the HandshakeInterceptor's to use.
42+
*/
43+
StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors);
44+
3945
}

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@
2323
import org.springframework.web.HttpRequestHandler;
2424
import org.springframework.web.socket.WebSocketHandler;
2525
import org.springframework.web.socket.server.HandshakeHandler;
26+
import org.springframework.web.socket.server.HandshakeInterceptor;
2627
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
2728
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
2829
import org.springframework.web.socket.sockjs.SockJsService;
2930
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
3031

32+
import java.util.Arrays;
33+
3134
/**
3235
* An abstract base class class for configuring STOMP over WebSocket/SockJS endpoints.
3336
*
@@ -44,6 +47,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
4447

4548
private HandshakeHandler handshakeHandler;
4649

50+
private HandshakeInterceptor[] interceptors;
51+
4752
private StompSockJsServiceRegistration registration;
4853

4954

@@ -58,22 +63,29 @@ public WebMvcStompWebSocketEndpointRegistration(String[] paths, WebSocketHandler
5863
this.sockJsTaskScheduler = sockJsTaskScheduler;
5964
}
6065

61-
/**
62-
* Provide a custom or pre-configured {@link HandshakeHandler}.
63-
*/
6466
@Override
6567
public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
6668
Assert.notNull(handshakeHandler, "'handshakeHandler' must not be null");
6769
this.handshakeHandler = handshakeHandler;
6870
return this;
6971
}
7072

71-
/**
72-
* Enable SockJS fallback options.
73-
*/
73+
@Override
74+
public StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors) {
75+
this.interceptors = interceptors;
76+
return this;
77+
}
78+
79+
protected HandshakeInterceptor[] getInterceptors() {
80+
return this.interceptors;
81+
}
82+
7483
@Override
7584
public SockJsServiceRegistration withSockJS() {
7685
this.registration = new StompSockJsServiceRegistration(this.sockJsTaskScheduler);
86+
if (this.interceptors != null) {
87+
this.registration.setInterceptors(this.interceptors);
88+
}
7789
if (this.handshakeHandler != null) {
7890
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
7991
this.registration.setTransportHandlerOverrides(transportHandler);
@@ -93,9 +105,16 @@ public final MultiValueMap<HttpRequestHandler, String> getMappings() {
93105
}
94106
else {
95107
for (String path : this.paths) {
96-
WebSocketHttpRequestHandler handler = (this.handshakeHandler != null) ?
97-
new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler) :
98-
new WebSocketHttpRequestHandler(this.webSocketHandler);
108+
WebSocketHttpRequestHandler handler;
109+
if (this.handshakeHandler != null) {
110+
handler = new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler);
111+
}
112+
else {
113+
handler = new WebSocketHttpRequestHandler(this.webSocketHandler);
114+
}
115+
if (this.interceptors != null) {
116+
handler.setHandshakeInterceptors(Arrays.asList(this.interceptors));
117+
}
99118
mappings.add(handler, path);
100119
}
101120
}

spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
import org.springframework.util.MultiValueMap;
3030
import org.springframework.web.HttpRequestHandler;
3131
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
32+
import org.springframework.web.socket.server.HandshakeInterceptor;
3233
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
34+
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
3335
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
3436
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
3537
import org.springframework.web.socket.sockjs.transport.TransportHandler;
@@ -38,6 +40,8 @@
3840
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
3941

4042
import static org.junit.Assert.*;
43+
import static org.junit.Assert.assertArrayEquals;
44+
import static org.junit.Assert.assertEquals;
4145
import static org.mockito.Mockito.mock;
4246

4347
/**
@@ -73,12 +77,15 @@ public void minimalRegistration() {
7377
}
7478

7579
@Test
76-
public void customHandshakeHandler() {
80+
public void handshakeHandlerAndInterceptors() {
7781
WebMvcStompWebSocketEndpointRegistration registration =
7882
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
7983

8084
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
85+
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
86+
8187
registration.setHandshakeHandler(handshakeHandler);
88+
registration.addInterceptors(interceptor);
8289

8390
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
8491
assertEquals(1, mappings.size());
@@ -89,15 +96,19 @@ public void customHandshakeHandler() {
8996
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
9097
assertNotNull(requestHandler.getWebSocketHandler());
9198
assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
99+
assertEquals(Arrays.asList(interceptor), requestHandler.getHandshakeInterceptors());
92100
}
93101

94102
@Test
95-
public void customHandshakeHandlerPassedToSockJsService() {
103+
public void handshakeHandlerAndInterceptorsWithSockJsService() {
96104
WebMvcStompWebSocketEndpointRegistration registration =
97105
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
98106

99107
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
108+
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
109+
100110
registration.setHandshakeHandler(handshakeHandler);
111+
registration.addInterceptors(interceptor);
101112
registration.withSockJS();
102113

103114
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
@@ -115,6 +126,7 @@ public void customHandshakeHandlerPassedToSockJsService() {
115126
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
116127
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
117128
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
129+
assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors());
118130
}
119131

120132
}

0 commit comments

Comments
 (0)