Skip to content

Commit 3694a62

Browse files
committed
Initial WebSocket compression implementation
1 parent 4d97221 commit 3694a62

File tree

2 files changed

+213
-1
lines changed

2 files changed

+213
-1
lines changed

src/Middleware/WebSockets/src/WebSocketMiddleware.cs

Lines changed: 178 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33

44
using System;
55
using System.Collections.Generic;
6+
using System.Diagnostics;
7+
using System.Globalization;
68
using System.IO;
79
using System.Linq;
10+
using System.Net.Http.Headers;
811
using System.Net.WebSockets;
12+
using System.Text;
913
using System.Threading.Tasks;
1014
using Microsoft.AspNetCore.Builder;
1115
using Microsoft.AspNetCore.Http;
@@ -151,9 +155,46 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
151155

152156
HandshakeHelpers.GenerateResponseHeaders(key, subProtocol, _context.Response.Headers);
153157

158+
// TODO: get from options
159+
WebSocketDeflateOptions? deflateOptions = null;
160+
var ext = _context.Request.Headers["Sec-WebSocket-Extensions"];
161+
if (ext.Count != 0)
162+
{
163+
var decline = false;
164+
foreach (var extension in ext)
165+
{
166+
if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension, StringComparison.Ordinal))
167+
{
168+
deflateOptions = new();
169+
if (ParseDeflateOptions(extension, deflateOptions, out var hasClientMaxWindowBits))
170+
{
171+
Resp(_context.Response.Headers, deflateOptions, hasClientMaxWindowBits);
172+
decline = false;
173+
break;
174+
}
175+
else
176+
{
177+
decline = true;
178+
}
179+
}
180+
}
181+
if (decline)
182+
{
183+
throw new InvalidOperationException("'permessage-deflate' extension not accepted.");
184+
}
185+
}
186+
154187
Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101
155188

156-
return WebSocket.CreateFromStream(opaqueTransport, isServer: true, subProtocol: subProtocol, keepAliveInterval: keepAliveInterval);
189+
var options = new WebSocketCreationOptions()
190+
{
191+
IsServer = true,
192+
KeepAliveInterval = keepAliveInterval,
193+
SubProtocol = subProtocol,
194+
DangerousDeflateOptions = deflateOptions,
195+
};
196+
197+
return WebSocket.CreateFromStream(opaqueTransport, options);
157198
}
158199

159200
public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
@@ -226,6 +267,142 @@ public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictiona
226267

227268
return HandshakeHelpers.IsRequestKeyValid(requestHeaders.SecWebSocketKey.ToString());
228269
}
270+
271+
internal static class ClientWebSocketDeflateConstants
272+
{
273+
/// <summary>
274+
/// The maximum length that this extension can have, assuming that we're not abusing white space.
275+
/// <para />
276+
/// "permessage-deflate; client_max_window_bits=15; client_no_context_takeover; server_max_window_bits=15; server_no_context_takeover"
277+
/// </summary>
278+
public const int MaxExtensionLength = 128;
279+
280+
public const string Extension = "permessage-deflate";
281+
282+
public const string ClientMaxWindowBits = "client_max_window_bits";
283+
public const string ClientNoContextTakeover = "client_no_context_takeover";
284+
285+
public const string ServerMaxWindowBits = "server_max_window_bits";
286+
public const string ServerNoContextTakeover = "server_no_context_takeover";
287+
}
288+
289+
private static bool ParseDeflateOptions(ReadOnlySpan<char> extension, WebSocketDeflateOptions options, out bool hasClientMaxWindowBits)
290+
{
291+
hasClientMaxWindowBits = false;
292+
while (true)
293+
{
294+
int end = extension.IndexOf(';');
295+
ReadOnlySpan<char> value = (end >= 0 ? extension[..end] : extension).Trim();
296+
297+
if (value.Length > 0)
298+
{
299+
if (value.SequenceEqual(ClientWebSocketDeflateConstants.ClientNoContextTakeover))
300+
{
301+
options.ClientContextTakeover = false;
302+
}
303+
else if (value.SequenceEqual(ClientWebSocketDeflateConstants.ServerNoContextTakeover))
304+
{
305+
options.ServerContextTakeover = false;
306+
}
307+
else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits))
308+
{
309+
hasClientMaxWindowBits = true;
310+
var clientMaxWindowBits = ParseWindowBits(value);
311+
if (clientMaxWindowBits > options.ClientMaxWindowBits)
312+
{
313+
return false;
314+
}
315+
options.ClientMaxWindowBits = clientMaxWindowBits;
316+
}
317+
else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits))
318+
{
319+
var serverMaxWindowBits = ParseWindowBits(value);
320+
if (serverMaxWindowBits > options.ServerMaxWindowBits)
321+
{
322+
return false;
323+
}
324+
options.ServerMaxWindowBits = serverMaxWindowBits;
325+
}
326+
327+
static int ParseWindowBits(ReadOnlySpan<char> value)
328+
{
329+
// parameters can be sent without a value by the client
330+
var startIndex = value.IndexOf('=');
331+
332+
if (startIndex < 0 ||
333+
!int.TryParse(value[(startIndex + 1)..], NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) ||
334+
windowBits < 9 ||
335+
windowBits > 15)
336+
{
337+
throw new WebSocketException(WebSocketError.HeaderError, "");
338+
}
339+
340+
return windowBits;
341+
}
342+
}
343+
344+
if (end < 0)
345+
{
346+
break;
347+
}
348+
extension = extension[(end + 1)..];
349+
}
350+
351+
return true;
352+
}
353+
354+
private static void Resp(IHeaderDictionary headers, WebSocketDeflateOptions options, bool hasClientMaxWindowBits)
355+
{
356+
headers.Add("Sec-WebSocket-Extensions", GetDeflateOptions(options, hasClientMaxWindowBits));
357+
358+
static string GetDeflateOptions(WebSocketDeflateOptions options, bool hasClientMaxWindowBits)
359+
{
360+
var builder = new StringBuilder(ClientWebSocketDeflateConstants.MaxExtensionLength);
361+
builder.Append(ClientWebSocketDeflateConstants.Extension);
362+
363+
// If a received extension negotiation offer doesn't have the
364+
// "client_max_window_bits" extension parameter, the corresponding
365+
// extension negotiation response to the offer MUST NOT include the
366+
// "client_max_window_bits" extension parameter.
367+
// https://tools.ietf.org/html/rfc7692#section-7.1.2.2
368+
if (hasClientMaxWindowBits)
369+
{
370+
if (options.ClientMaxWindowBits != 15)
371+
{
372+
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=')
373+
.Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture));
374+
}
375+
else
376+
{
377+
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits);
378+
}
379+
}
380+
381+
if (!options.ClientContextTakeover)
382+
{
383+
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientNoContextTakeover);
384+
}
385+
386+
if (options.ServerMaxWindowBits != 15)
387+
{
388+
builder.Append("; ")
389+
.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=')
390+
.Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture));
391+
}
392+
else
393+
{
394+
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits);
395+
}
396+
397+
if (!options.ServerContextTakeover)
398+
{
399+
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ServerNoContextTakeover);
400+
}
401+
402+
Debug.Assert(builder.Length <= ClientWebSocketDeflateConstants.MaxExtensionLength);
403+
return builder.ToString();
404+
}
405+
}
229406
}
230407
}
231408
}

src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5+
using System.Linq;
56
using System.Net;
67
using System.Net.Http;
78
using System.Net.WebSockets;
@@ -632,5 +633,39 @@ public async Task MultipleValueHeadersNotOverridden()
632633
}
633634
}
634635
}
636+
637+
[Fact]
638+
public async Task InitialCompression()
639+
{
640+
await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
641+
{
642+
Assert.True(context.WebSockets.IsWebSocketRequest);
643+
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
644+
}))
645+
{
646+
using (var client = new HttpClient())
647+
{
648+
var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/"));
649+
uri.Scheme = "http";
650+
651+
// Craft a valid WebSocket Upgrade request
652+
using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
653+
{
654+
request.Headers.Connection.Clear();
655+
request.Headers.Connection.Add("Upgrade");
656+
request.Headers.Connection.Add("keep-alive");
657+
request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket"));
658+
request.Headers.Add(HeaderNames.SecWebSocketVersion, "13");
659+
// SecWebSocketKey required to be 16 bytes
660+
request.Headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));
661+
request.Headers.Add("Sec-WebSocket-Extensions", "permessage-deflate");
662+
663+
var response = await client.SendAsync(request);
664+
Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);
665+
Assert.Equal("permessage-deflate; server_max_window_bits", response.Headers.GetValues("Sec-WebSocket-Extensions").Aggregate((l, r) => $"{l}; {r}"));
666+
}
667+
}
668+
}
669+
}
635670
}
636671
}

0 commit comments

Comments
 (0)