1616
1717package org .springframework .web .socket .server .standard ;
1818
19+ import java .lang .reflect .Constructor ;
1920import java .util .Arrays ;
2021import java .util .Collections ;
2122import java .util .List ;
23+ import java .util .Set ;
24+ import java .util .concurrent .ConcurrentHashMap ;
2225import javax .servlet .http .HttpServletRequest ;
2326import javax .servlet .http .HttpServletResponse ;
2427import javax .websocket .Decoder ;
3437import io .undertow .websockets .core .WebSocketChannel ;
3538import io .undertow .websockets .core .WebSocketVersion ;
3639import io .undertow .websockets .core .protocol .Handshake ;
37- import io .undertow .websockets .core .protocol .version07 .Hybi07Handshake ;
38- import io .undertow .websockets .core .protocol .version08 .Hybi08Handshake ;
39- import io .undertow .websockets .core .protocol .version13 .Hybi13Handshake ;
4040import io .undertow .websockets .jsr .ConfiguredServerEndpoint ;
4141import io .undertow .websockets .jsr .EncodingFactory ;
4242import io .undertow .websockets .jsr .EndpointSessionHandler ;
4545import io .undertow .websockets .jsr .handshake .JsrHybi07Handshake ;
4646import io .undertow .websockets .jsr .handshake .JsrHybi08Handshake ;
4747import io .undertow .websockets .jsr .handshake .JsrHybi13Handshake ;
48+ import org .springframework .util .ClassUtils ;
4849import org .xnio .StreamConnection ;
4950
5051import org .springframework .http .server .ServerHttpRequest ;
6162 */
6263public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
6364
64- private final String [] supportedVersions = new String [] {
65+ private static final Constructor <ServletWebSocketHttpExchange > exchangeConstructor ;
66+
67+ private static final boolean undertow10Present ;
68+
69+ static {
70+ Class <ServletWebSocketHttpExchange > type = ServletWebSocketHttpExchange .class ;
71+ Class <?>[] paramTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class , Set .class };
72+ if (ClassUtils .hasConstructor (type , paramTypes )) {
73+ exchangeConstructor = ClassUtils .getConstructorIfAvailable (type , paramTypes );
74+ undertow10Present = false ;
75+ }
76+ else {
77+ paramTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class };
78+ exchangeConstructor = ClassUtils .getConstructorIfAvailable (type , paramTypes );
79+ undertow10Present = true ;
80+ }
81+ }
82+
83+ private static final String [] supportedVersions = new String [] {
6584 WebSocketVersion .V13 .toHttpHeaderValue (),
6685 WebSocketVersion .V08 .toHttpHeaderValue (),
6786 WebSocketVersion .V07 .toHttpHeaderValue ()
6887 };
6988
7089
90+ private Set <WebSocketChannel > peerConnections ;
91+
92+
93+ public UndertowRequestUpgradeStrategy () {
94+ if (undertow10Present ) {
95+ this .peerConnections = null ;
96+ }
97+ else {
98+ this .peerConnections = Collections .newSetFromMap (new ConcurrentHashMap <WebSocketChannel , Boolean >());
99+ }
100+ }
101+
102+
71103 @ Override
72104 public String [] getSupportedVersions () {
73- return this . supportedVersions ;
105+ return supportedVersions ;
74106 }
75107
76108 @ Override
@@ -80,7 +112,7 @@ protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse res
80112 HttpServletRequest servletRequest = getHttpServletRequest (request );
81113 HttpServletResponse servletResponse = getHttpServletResponse (response );
82114
83- final ServletWebSocketHttpExchange exchange = new ServletWebSocketHttpExchange (servletRequest , servletResponse );
115+ final ServletWebSocketHttpExchange exchange = createHttpExchange (servletRequest , servletResponse );
84116 exchange .putAttachment (HandshakeUtil .PATH_PARAMS , Collections .<String , String >emptyMap ());
85117
86118 ServerWebSocketContainer wsContainer = (ServerWebSocketContainer ) getContainer (servletRequest );
@@ -95,13 +127,27 @@ protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse res
95127 @ Override
96128 public void handleUpgrade (StreamConnection connection , HttpServerExchange serverExchange ) {
97129 WebSocketChannel channel = handshake .createChannel (exchange , connection , exchange .getBufferPool ());
130+ if (peerConnections != null ) {
131+ peerConnections .add (channel );
132+ }
98133 endpointSessionHandler .onConnect (exchange , channel );
99134 }
100135 });
101136
102137 handshake .handshake (exchange );
103138 }
104139
140+ private ServletWebSocketHttpExchange createHttpExchange (HttpServletRequest request , HttpServletResponse response ) {
141+ try {
142+ return (this .peerConnections != null ?
143+ exchangeConstructor .newInstance (request , response , this .peerConnections ) :
144+ exchangeConstructor .newInstance (request , response ));
145+ }
146+ catch (Exception ex ) {
147+ throw new HandshakeFailureException ("Failed to instantiate ServletWebSocketHttpExchange" , ex );
148+ }
149+ }
150+
105151 private Handshake getHandshakeToUse (ServletWebSocketHttpExchange exchange , ConfiguredServerEndpoint endpoint ) {
106152 Handshake handshake = new JsrHybi13Handshake (endpoint );
107153 if (handshake .matches (exchange )) {
0 commit comments