Skip to content

Commit 71c1b11

Browse files
committed
fix and some tests
1 parent 3694a62 commit 71c1b11

File tree

2 files changed

+37
-33
lines changed

2 files changed

+37
-33
lines changed

src/Middleware/WebSockets/src/WebSocketMiddleware.cs

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
161161
if (ext.Count != 0)
162162
{
163163
var decline = false;
164-
foreach (var extension in ext)
164+
// loop over each extension offer, extensions can have multiple offers we can accept any
165+
foreach (var extension in _context.Request.Headers.GetCommaSeparatedValues("Sec-WebSocket-Extensions"))
165166
{
166167
if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension, StringComparison.Ordinal))
167168
{
168169
deflateOptions = new();
169170
if (ParseDeflateOptions(extension, deflateOptions, out var hasClientMaxWindowBits))
170171
{
171-
Resp(_context.Response.Headers, deflateOptions, hasClientMaxWindowBits);
172+
WriteDeflatNegotiateResponseHeader(_context.Response.Headers, deflateOptions, hasClientMaxWindowBits);
172173
decline = false;
173174
break;
174175
}
@@ -180,7 +181,8 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
180181
}
181182
if (decline)
182183
{
183-
throw new InvalidOperationException("'permessage-deflate' extension not accepted.");
184+
// TODO: Do we care?
185+
throw new WebSocketException(WebSocketError.HeaderError, "'permessage-deflate' extension not accepted.");
184186
}
185187
}
186188

@@ -312,7 +314,8 @@ private static bool ParseDeflateOptions(ReadOnlySpan<char> extension, WebSocketD
312314
{
313315
return false;
314316
}
315-
options.ClientMaxWindowBits = clientMaxWindowBits;
317+
// if client didn't send a value for ClientMaxWindowBits use the value the server set
318+
options.ClientMaxWindowBits = clientMaxWindowBits ?? options.ClientMaxWindowBits;
316319
}
317320
else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits))
318321
{
@@ -321,19 +324,25 @@ private static bool ParseDeflateOptions(ReadOnlySpan<char> extension, WebSocketD
321324
{
322325
return false;
323326
}
324-
options.ServerMaxWindowBits = serverMaxWindowBits;
327+
// if client didn't send a value for ServerMaxWindowBits use the value the server set
328+
options.ServerMaxWindowBits = serverMaxWindowBits ?? options.ServerMaxWindowBits;
325329
}
326330

327-
static int ParseWindowBits(ReadOnlySpan<char> value)
331+
static int? ParseWindowBits(ReadOnlySpan<char> value)
328332
{
329-
// parameters can be sent without a value by the client
330333
var startIndex = value.IndexOf('=');
331334

332-
if (startIndex < 0 ||
333-
!int.TryParse(value[(startIndex + 1)..], NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) ||
335+
// parameters can be sent without a value by the client
336+
if (startIndex < 0)
337+
{
338+
return null;
339+
}
340+
341+
if (!int.TryParse(value[(startIndex + 1)..], NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) ||
334342
windowBits < 9 ||
335343
windowBits > 15)
336344
{
345+
// TODO
337346
throw new WebSocketException(WebSocketError.HeaderError, "");
338347
}
339348

@@ -351,7 +360,7 @@ static int ParseWindowBits(ReadOnlySpan<char> value)
351360
return true;
352361
}
353362

354-
private static void Resp(IHeaderDictionary headers, WebSocketDeflateOptions options, bool hasClientMaxWindowBits)
363+
private static void WriteDeflatNegotiateResponseHeader(IHeaderDictionary headers, WebSocketDeflateOptions options, bool hasClientMaxWindowBits)
355364
{
356365
headers.Add("Sec-WebSocket-Extensions", GetDeflateOptions(options, hasClientMaxWindowBits));
357366

@@ -367,32 +376,19 @@ static string GetDeflateOptions(WebSocketDeflateOptions options, bool hasClientM
367376
// https://tools.ietf.org/html/rfc7692#section-7.1.2.2
368377
if (hasClientMaxWindowBits)
369378
{
370-
if (options.ClientMaxWindowBits != 15)
371-
{
372-
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=')
373-
.Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture));
374-
}
375-
else
376-
{
377-
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits);
378-
}
379+
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=')
380+
.Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture));
379381
}
380382

381383
if (!options.ClientContextTakeover)
382384
{
383385
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientNoContextTakeover);
384386
}
385387

386-
if (options.ServerMaxWindowBits != 15)
387-
{
388-
builder.Append("; ")
389-
.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=')
390-
.Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture));
391-
}
392-
else
393-
{
394-
builder.Append("; ").Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits);
395-
}
388+
// TODO: we could avoid sending this in some cases
389+
builder.Append("; ")
390+
.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=')
391+
.Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture));
396392

397393
if (!options.ServerContextTakeover)
398394
{

src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,16 @@ public async Task MultipleValueHeadersNotOverridden()
634634
}
635635
}
636636

637-
[Fact]
638-
public async Task InitialCompression()
637+
[Theory]
638+
[InlineData("permessage-deflate", "permessage-deflate; server_max_window_bits=15")]
639+
[InlineData("permessage-deflate; server_no_context_takeover", "permessage-deflate; server_max_window_bits=15; server_no_context_takeover")]
640+
[InlineData("permessage-deflate; client_no_context_takeover", "permessage-deflate; client_no_context_takeover; server_max_window_bits=15")]
641+
[InlineData("permessage-deflate; client_max_window_bits=9", "permessage-deflate; client_max_window_bits=9; server_max_window_bits=15")]
642+
[InlineData("permessage-deflate; client_max_window_bits", "permessage-deflate; client_max_window_bits=15; server_max_window_bits=15")]
643+
[InlineData("permessage-deflate; server_max_window_bits", "permessage-deflate; server_max_window_bits=15")]
644+
[InlineData("permessage-deflate; server_max_window_bits=10", "permessage-deflate; server_max_window_bits=10")]
645+
[InlineData("permessage-deflate; server_max_window_bits=10; server_no_context_takeover", "permessage-deflate; server_max_window_bits=10; server_no_context_takeover")]
646+
public async Task CompressionNegotiationProducesCorrectHeader(string clientHeader, string expectedResponse)
639647
{
640648
await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
641649
{
@@ -658,11 +666,11 @@ public async Task InitialCompression()
658666
request.Headers.Add(HeaderNames.SecWebSocketVersion, "13");
659667
// SecWebSocketKey required to be 16 bytes
660668
request.Headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));
661-
request.Headers.Add("Sec-WebSocket-Extensions", "permessage-deflate");
669+
request.Headers.Add("Sec-WebSocket-Extensions", clientHeader);
662670

663671
var response = await client.SendAsync(request);
664672
Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);
665-
Assert.Equal("permessage-deflate; server_max_window_bits", response.Headers.GetValues("Sec-WebSocket-Extensions").Aggregate((l, r) => $"{l}; {r}"));
673+
Assert.Equal(expectedResponse, response.Headers.GetValues("Sec-WebSocket-Extensions").Aggregate((l, r) => $"{l}; {r}"));
666674
}
667675
}
668676
}

0 commit comments

Comments
 (0)