@@ -105,6 +105,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
105105
106106 private MessageHeaderInitializer headerInitializer ;
107107
108+ private final Map <String , Principal > stompAuthentications = new ConcurrentHashMap <String , Principal >();
109+
108110 private Boolean immutableMessageInterceptorPresent ;
109111
110112 private ApplicationEventPublisher eventPublisher ;
@@ -272,11 +274,10 @@ else if (webSocketMessage instanceof BinaryMessage) {
272274 try {
273275 StompHeaderAccessor headerAccessor =
274276 MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
275- Principal user = session .getPrincipal ();
276277
277278 headerAccessor .setSessionId (session .getId ());
278279 headerAccessor .setSessionAttributes (session .getAttributes ());
279- headerAccessor .setUser (user );
280+ headerAccessor .setUser (getUser ( session ) );
280281 headerAccessor .setHeader (SimpMessageHeaderAccessor .HEART_BEAT_HEADER , headerAccessor .getHeartbeat ());
281282 if (!detectImmutableMessageInterceptor (outputChannel )) {
282283 headerAccessor .setImmutable ();
@@ -286,7 +287,8 @@ else if (webSocketMessage instanceof BinaryMessage) {
286287 logger .trace ("From client: " + headerAccessor .getShortLogMessage (message .getPayload ()));
287288 }
288289
289- if (StompCommand .CONNECT .equals (headerAccessor .getCommand ())) {
290+ boolean isConnect = StompCommand .CONNECT .equals (headerAccessor .getCommand ());
291+ if (isConnect ) {
290292 this .stats .incrementConnectCount ();
291293 }
292294 else if (StompCommand .DISCONNECT .equals (headerAccessor .getCommand ())) {
@@ -297,15 +299,23 @@ else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
297299 SimpAttributesContextHolder .setAttributesFromMessage (message );
298300 boolean sent = outputChannel .send (message );
299301
300- if (sent && this . eventPublisher != null ) {
301- if (StompCommand . CONNECT . equals ( headerAccessor . getCommand ()) ) {
302- publishEvent ( new SessionConnectEvent ( this , message , user ) );
303- }
304- else if ( StompCommand . SUBSCRIBE . equals ( headerAccessor . getCommand ())) {
305- publishEvent ( new SessionSubscribeEvent ( this , message , user ));
302+ if (sent ) {
303+ if (isConnect ) {
304+ Principal user = headerAccessor . getUser ( );
305+ if ( user != null && user != session . getPrincipal ()) {
306+ this . stompAuthentications . put ( session . getId (), user );
307+ }
306308 }
307- else if (StompCommand .UNSUBSCRIBE .equals (headerAccessor .getCommand ())) {
308- publishEvent (new SessionUnsubscribeEvent (this , message , user ));
309+ if (this .eventPublisher != null ) {
310+ if (isConnect ) {
311+ publishEvent (new SessionConnectEvent (this , message , getUser (session )));
312+ }
313+ else if (StompCommand .SUBSCRIBE .equals (headerAccessor .getCommand ())) {
314+ publishEvent (new SessionSubscribeEvent (this , message , getUser (session )));
315+ }
316+ else if (StompCommand .UNSUBSCRIBE .equals (headerAccessor .getCommand ())) {
317+ publishEvent (new SessionUnsubscribeEvent (this , message , getUser (session )));
318+ }
309319 }
310320 }
311321 }
@@ -323,6 +333,11 @@ else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
323333 }
324334 }
325335
336+ private Principal getUser (WebSocketSession session ) {
337+ Principal user = this .stompAuthentications .get (session .getId ());
338+ return user != null ? user : session .getPrincipal ();
339+ }
340+
326341 @ SuppressWarnings ("deprecation" )
327342 private void handleError (WebSocketSession session , Throwable ex , Message <byte []> clientMessage ) {
328343 if (getErrorHandler () == null ) {
@@ -425,7 +440,7 @@ else if (StompCommand.CONNECTED.equals(command)) {
425440 try {
426441 SimpAttributes simpAttributes = new SimpAttributes (session .getId (), session .getAttributes ());
427442 SimpAttributesContextHolder .setAttributes (simpAttributes );
428- Principal user = session . getPrincipal ( );
443+ Principal user = getUser ( session );
429444 publishEvent (new SessionConnectedEvent (this , (Message <byte []>) message , user ));
430445 }
431446 finally {
@@ -566,7 +581,7 @@ protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccess
566581 private StompHeaderAccessor afterStompSessionConnected (Message <?> message , StompHeaderAccessor accessor ,
567582 WebSocketSession session ) {
568583
569- Principal principal = session . getPrincipal ( );
584+ Principal principal = getUser ( session );
570585 if (principal != null ) {
571586 accessor = toMutableAccessor (accessor , message );
572587 accessor .setNativeHeader (CONNECTED_USER_HEADER , principal .getName ());
@@ -613,7 +628,7 @@ public void afterSessionStarted(WebSocketSession session, MessageChannel outputC
613628 public void afterSessionEnded (WebSocketSession session , CloseStatus closeStatus , MessageChannel outputChannel ) {
614629 this .decoders .remove (session .getId ());
615630
616- Principal principal = session . getPrincipal ( );
631+ Principal principal = getUser ( session );
617632 if (principal != null && this .userSessionRegistry != null ) {
618633 String userName = getSessionRegistryUserName (principal );
619634 this .userSessionRegistry .unregisterSessionId (userName , session .getId ());
@@ -624,12 +639,13 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
624639 try {
625640 SimpAttributesContextHolder .setAttributes (simpAttributes );
626641 if (this .eventPublisher != null ) {
627- Principal user = session . getPrincipal ( );
642+ Principal user = getUser ( session );
628643 publishEvent (new SessionDisconnectEvent (this , message , session .getId (), closeStatus , user ));
629644 }
630645 outputChannel .send (message );
631646 }
632647 finally {
648+ this .stompAuthentications .remove (session .getId ());
633649 SimpAttributesContextHolder .resetAttributes ();
634650 simpAttributes .sessionCompleted ();
635651 }
@@ -642,7 +658,7 @@ private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
642658 }
643659 headerAccessor .setSessionId (session .getId ());
644660 headerAccessor .setSessionAttributes (session .getAttributes ());
645- headerAccessor .setUser (session . getPrincipal ( ));
661+ headerAccessor .setUser (getUser ( session ));
646662 return MessageBuilder .createMessage (EMPTY_PAYLOAD , headerAccessor .getMessageHeaders ());
647663 }
648664
0 commit comments