1717package org .springframework .web .socket .server .support ;
1818
1919import java .io .IOException ;
20+ import java .lang .reflect .Method ;
2021import java .util .Arrays ;
2122import java .util .Collections ;
2223import java .util .Map ;
2627import javax .servlet .http .HttpServletRequest ;
2728import javax .servlet .http .HttpServletResponse ;
2829import javax .websocket .Endpoint ;
30+ import javax .websocket .server .ServerEndpointConfig ;
2931
32+ import org .apache .tomcat .websocket .server .WsHandshakeRequest ;
33+ import org .apache .tomcat .websocket .server .WsHttpUpgradeHandler ;
3034import org .apache .tomcat .websocket .server .WsServerContainer ;
3135import org .springframework .http .server .ServerHttpRequest ;
3236import org .springframework .http .server .ServerHttpResponse ;
3337import org .springframework .http .server .ServletServerHttpRequest ;
3438import org .springframework .http .server .ServletServerHttpResponse ;
3539import org .springframework .util .Assert ;
40+ import org .springframework .util .ReflectionUtils ;
3641import org .springframework .web .socket .server .HandshakeFailureException ;
3742import org .springframework .web .socket .server .endpoint .ServerEndpointRegistration ;
3843
@@ -60,6 +65,18 @@ public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse respon
6065 Assert .isTrue (response instanceof ServletServerHttpResponse );
6166 HttpServletResponse servletResponse = ((ServletServerHttpResponse ) response ).getServletResponse ();
6267
68+ if (hasDoUpgrade ) {
69+ doUpgrade (servletRequest , servletResponse , acceptedProtocol , endpoint );
70+ }
71+ else {
72+ upgradeTomcat80RC1 (servletRequest , acceptedProtocol , endpoint );
73+ }
74+ }
75+
76+ private void doUpgrade (HttpServletRequest servletRequest , HttpServletResponse servletResponse ,
77+ String acceptedProtocol , Endpoint endpoint ) {
78+
79+ StringBuffer requestUrl = servletRequest .getRequestURL ();
6380 String path = servletRequest .getRequestURI (); // shouldn't matter
6481 Map <String , String > pathParams = Collections .<String , String > emptyMap ();
6582
@@ -71,11 +88,11 @@ public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse respon
7188 }
7289 catch (ServletException ex ) {
7390 throw new HandshakeFailureException (
74- "Servlet request failed to upgrade to WebSocket, uri=" + request . getURI () , ex );
91+ "Servlet request failed to upgrade to WebSocket, uri=" + requestUrl , ex );
7592 }
7693 catch (IOException ex ) {
7794 throw new HandshakeFailureException (
78- "Response update failed during upgrade to WebSocket, uri=" + request . getURI () , ex );
95+ "Response update failed during upgrade to WebSocket, uri=" + requestUrl , ex );
7996 }
8097 }
8198
@@ -85,4 +102,36 @@ private WsServerContainer getContainer(HttpServletRequest servletRequest) {
85102 return (WsServerContainer ) servletContext .getAttribute (attribute );
86103 }
87104
88- }
105+ // FIXME: Remove this after RC2 is out
106+
107+ private void upgradeTomcat80RC1 (HttpServletRequest request , String protocol , Endpoint endpoint ) {
108+
109+ WsHttpUpgradeHandler upgradeHandler ;
110+ try {
111+ upgradeHandler = request .upgrade (WsHttpUpgradeHandler .class );
112+ }
113+ catch (Exception e ) {
114+ throw new HandshakeFailureException ("Unable to create UpgardeHandler" , e );
115+ }
116+
117+ WsHandshakeRequest webSocketRequest = new WsHandshakeRequest (request );
118+ try {
119+ Method method = ReflectionUtils .findMethod (WsHandshakeRequest .class , "finished" );
120+ ReflectionUtils .makeAccessible (method );
121+ method .invoke (webSocketRequest );
122+ }
123+ catch (Exception ex ) {
124+ throw new HandshakeFailureException ("Failed to upgrade HttpServletRequest" , ex );
125+ }
126+
127+ ServerEndpointConfig endpointConfig = new ServerEndpointRegistration ("/shouldntmatter" , endpoint );
128+
129+ upgradeHandler .preInit (endpoint , endpointConfig , getContainer (request ), webSocketRequest ,
130+ protocol , Collections .<String , String > emptyMap (), request .isSecure ());
131+ }
132+
133+ private static boolean hasDoUpgrade = (ReflectionUtils .findMethod (WsServerContainer .class ,
134+ "doUpgrade" , HttpServletRequest .class , HttpServletResponse .class ,
135+ ServerEndpointConfig .class , Map .class ) != null );
136+
137+ }
0 commit comments