Skip to content

Commit e4ad2b3

Browse files
committed
Add DestinationUserNameProvider interface
The interface is to be implemented in addition to java.security.Principal when Principal.getName() is not globally unique enough for use in user destinations. Issue: SPR-11327
1 parent 2cafe9d commit e4ad2b3

File tree

5 files changed

+117
-27
lines changed

5 files changed

+117
-27
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ protected boolean checkDestination(String destination, String requiredPrefix) {
165165
return true;
166166
}
167167

168-
protected String getTargetDestination(String origDestination, String targetDestination, String sessionId, String user) {
168+
protected String getTargetDestination(String origDestination, String targetDestination,
169+
String sessionId, String user) {
170+
169171
return targetDestination + "-user" + sessionId;
170172
}
171173

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package org.springframework.messaging.simp.user;
2+
3+
/**
4+
* An interface to be implemented in addition to {@link java.security.Principal}
5+
* when {@link java.security.Principal#getName()} is not globally unique enough
6+
* for use in user destinations. For more on user destination see
7+
* {@link org.springframework.messaging.simp.user.UserDestinationResolver}.
8+
*
9+
* @author Rossen Stoyanchev
10+
* @since 4.0.1
11+
*/
12+
public interface DestinationUserNameProvider {
13+
14+
15+
/**
16+
* Return the (globally unique) user name to use with user destinations.
17+
*/
18+
String getDestinationUserName();
19+
20+
}

spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,21 @@ public class DefaultUserDestinationResolverTests {
3939

4040
private UserSessionRegistry registry;
4141

42+
private TestPrincipal user;
43+
4244

4345
@Before
4446
public void setup() {
47+
this.user = new TestPrincipal("joe");
4548
this.registry = new DefaultUserSessionRegistry();
46-
this.registry.registerSessionId("joe", SESSION_ID);
47-
49+
this.registry.registerSessionId(this.user.getName(), SESSION_ID);
4850
this.resolver = new DefaultUserDestinationResolver(this.registry);
4951
}
5052

5153

5254
@Test
5355
public void handleSubscribe() {
54-
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
56+
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
5557
Set<String> actual = this.resolver.resolveDestination(message);
5658

5759
assertEquals(1, actual.size());
@@ -66,7 +68,7 @@ public void handleSubscribeOneUserMultipleSessions() {
6668
this.registry.registerSessionId("joe", "456");
6769
this.registry.registerSessionId("joe", "789");
6870

69-
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
71+
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
7072
Set<String> actual = this.resolver.resolveDestination(message);
7173

7274
assertEquals(1, actual.size());
@@ -75,7 +77,7 @@ public void handleSubscribeOneUserMultipleSessions() {
7577

7678
@Test
7779
public void handleUnsubscribe() {
78-
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
80+
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
7981
Set<String> actual = this.resolver.resolveDestination(message);
8082

8183
assertEquals(1, actual.size());
@@ -84,7 +86,7 @@ public void handleUnsubscribe() {
8486

8587
@Test
8688
public void handleMessage() {
87-
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/user/joe/queue/foo");
89+
Message<?> message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/user/joe/queue/foo");
8890
Set<String> actual = this.resolver.resolveDestination(message);
8991

9092
assertEquals(1, actual.size());
@@ -96,12 +98,12 @@ public void handleMessage() {
9698
public void ignoreMessage() {
9799

98100
// no destination
99-
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, null);
101+
Message<?> message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, null);
100102
Set<String> actual = this.resolver.resolveDestination(message);
101103
assertEquals(0, actual.size());
102104

103105
// not a user destination
104-
message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/queue/foo");
106+
message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/queue/foo");
105107
actual = this.resolver.resolveDestination(message);
106108
assertEquals(0, actual.size());
107109

@@ -111,24 +113,24 @@ public void ignoreMessage() {
111113
assertEquals(0, actual.size());
112114

113115
// subscribe + not a user destination
114-
message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/queue/foo");
116+
message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/queue/foo");
115117
actual = this.resolver.resolveDestination(message);
116118
assertEquals(0, actual.size());
117119

118120
// no match on message type
119-
message = createMessage(SimpMessageType.CONNECT, "joe", SESSION_ID, "user/joe/queue/foo");
121+
message = createMessage(SimpMessageType.CONNECT, this.user, SESSION_ID, "user/joe/queue/foo");
120122
actual = this.resolver.resolveDestination(message);
121123
assertEquals(0, actual.size());
122124
}
123125

124126

125-
private Message<?> createMessage(SimpMessageType messageType, String user, String sessionId, String destination) {
127+
private Message<?> createMessage(SimpMessageType messageType, TestPrincipal user, String sessionId, String destination) {
126128
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType);
127129
if (destination != null) {
128130
headers.setDestination(destination);
129131
}
130132
if (user != null) {
131-
headers.setUser(new TestPrincipal(user));
133+
headers.setUser(user);
132134
}
133135
if (sessionId != null) {
134136
headers.setSessionId(sessionId);

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.messaging.simp.stomp.StompDecoder;
3636
import org.springframework.messaging.simp.stomp.StompEncoder;
3737
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
38+
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
3839
import org.springframework.messaging.simp.user.UserSessionRegistry;
3940
import org.springframework.messaging.support.MessageBuilder;
4041
import org.springframework.util.Assert;
@@ -240,11 +241,20 @@ private void afterStompSessionConnected(StompHeaderAccessor headers, WebSocketSe
240241
if (principal != null) {
241242
headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
242243
if (this.userSessionRegistry != null) {
243-
this.userSessionRegistry.registerSessionId(principal.getName(), session.getId());
244+
String userName = getNameForUserSessionRegistry(principal);
245+
this.userSessionRegistry.registerSessionId(userName, session.getId());
244246
}
245247
}
246248
}
247249

250+
private String getNameForUserSessionRegistry(Principal principal) {
251+
String userName = principal.getName();
252+
if (principal instanceof DestinationUserNameProvider) {
253+
userName = ((DestinationUserNameProvider) principal).getDestinationUserName();
254+
}
255+
return userName;
256+
}
257+
248258
@Override
249259
public String resolveSessionId(Message<?> message) {
250260
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
@@ -258,8 +268,10 @@ public void afterSessionStarted(WebSocketSession session, MessageChannel outputC
258268
@Override
259269
public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
260270

261-
if ((this.userSessionRegistry != null) && (session.getPrincipal() != null)) {
262-
this.userSessionRegistry.unregisterSessionId(session.getPrincipal().getName(), session.getId());
271+
Principal principal = session.getPrincipal();
272+
if ((this.userSessionRegistry != null) && (principal != null)) {
273+
String userName = getNameForUserSessionRegistry(principal);
274+
this.userSessionRegistry.unregisterSessionId(userName, session.getId());
263275
}
264276

265277
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2013 the original author or authors.
2+
* Copyright 2002-2014 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
1818

1919
import java.nio.ByteBuffer;
2020
import java.util.Arrays;
21+
import java.util.Collections;
2122
import java.util.HashSet;
2223

2324
import org.junit.Before;
@@ -32,6 +33,9 @@
3233
import org.springframework.messaging.simp.stomp.StompCommand;
3334
import org.springframework.messaging.simp.stomp.StompDecoder;
3435
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
36+
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
37+
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
38+
import org.springframework.messaging.simp.user.UserSessionRegistry;
3539
import org.springframework.messaging.support.MessageBuilder;
3640
import org.springframework.web.socket.TextMessage;
3741
import org.springframework.web.socket.WebSocketMessage;
@@ -47,7 +51,7 @@
4751
*/
4852
public class StompSubProtocolHandlerTests {
4953

50-
private StompSubProtocolHandler stompHandler;
54+
private StompSubProtocolHandler protocolHandler;
5155

5256
private TestWebSocketSession session;
5357

@@ -58,7 +62,7 @@ public class StompSubProtocolHandlerTests {
5862

5963
@Before
6064
public void setup() {
61-
this.stompHandler = new StompSubProtocolHandler();
65+
this.protocolHandler = new StompSubProtocolHandler();
6266
this.channel = Mockito.mock(MessageChannel.class);
6367
this.messageCaptor = ArgumentCaptor.forClass(Message.class);
6468

@@ -68,18 +72,55 @@ public void setup() {
6872
}
6973

7074
@Test
71-
public void connectedResponseIsSentWhenConnectAckIsToBeSentToClient() {
75+
public void handleMessageToClientConnected() {
76+
77+
UserSessionRegistry registry = new DefaultUserSessionRegistry();
78+
this.protocolHandler.setUserSessionRegistry(registry);
79+
80+
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
81+
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
82+
this.protocolHandler.handleMessageToClient(this.session, message);
83+
84+
assertEquals(1, this.session.getSentMessages().size());
85+
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
86+
assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());
87+
88+
assertEquals(Collections.singleton("s1"), registry.getSessionIds("joe"));
89+
}
90+
91+
@Test
92+
public void handleMessageToClientConnectedUniqueUserName() {
93+
94+
this.session.setPrincipal(new UniqueUser("joe"));
95+
96+
UserSessionRegistry registry = new DefaultUserSessionRegistry();
97+
this.protocolHandler.setUserSessionRegistry(registry);
98+
99+
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
100+
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
101+
this.protocolHandler.handleMessageToClient(this.session, message);
102+
103+
assertEquals(1, this.session.getSentMessages().size());
104+
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
105+
assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());
106+
107+
assertEquals(Collections.<String>emptySet(), registry.getSessionIds("joe"));
108+
assertEquals(Collections.singleton("s1"), registry.getSessionIds("Me myself and I"));
109+
}
110+
111+
@Test
112+
public void handleMessageToClientConnectAck() {
113+
72114
StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
73115
connectHeaders.setHeartbeat(10000, 10000);
74116
connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1");
75-
76117
Message<?> connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build();
77118

78119
SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
79120
connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
121+
Message<byte[]> connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build();
80122

81-
Message<byte[]> connectAck = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build();
82-
this.stompHandler.handleMessageToClient(this.session, connectAck);
123+
this.protocolHandler.handleMessageToClient(this.session, connectAckMessage);
83124

84125
verifyNoMoreInteractions(this.channel);
85126

@@ -97,12 +138,12 @@ public void connectedResponseIsSentWhenConnectAckIsToBeSentToClient() {
97138
}
98139

99140
@Test
100-
public void messagesAreAugmentedAndForwarded() {
141+
public void handleMessageFromClient() {
101142

102143
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
103144
"login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
104145

105-
this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel);
146+
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
106147

107148
verify(this.channel).send(this.messageCaptor.capture());
108149
Message<?> actual = this.messageCaptor.getValue();
@@ -121,16 +162,29 @@ public void messagesAreAugmentedAndForwarded() {
121162
}
122163

123164
@Test
124-
public void invalidStompCommand() {
165+
public void handleMessageFromClientInvalidStompCommand() {
125166

126167
TextMessage textMessage = new TextMessage("FOO");
127168

128-
this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel);
169+
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
129170

130171
verifyZeroInteractions(this.channel);
131172
assertEquals(1, this.session.getSentMessages().size());
132173
TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
133174
assertTrue(actual.getPayload().startsWith("ERROR"));
134175
}
135176

177+
178+
private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider {
179+
180+
private UniqueUser(String name) {
181+
super(name);
182+
}
183+
184+
@Override
185+
public String getDestinationUserName() {
186+
return "Me myself and I";
187+
}
188+
}
189+
136190
}

0 commit comments

Comments
 (0)