From da6f21b941a2e6d93e46cdb49ce4389a1c5f0c80 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Wed, 16 Jun 2021 22:07:17 +1200 Subject: [PATCH 1/2] HTTP/3: Handle request completes with unread body content --- .../Core/src/Internal/Http3/Http3Stream.cs | 67 ++++++++++++++++++- .../Http3/Http3StreamTests.cs | 66 ++++++++++++++++-- 2 files changed, 123 insertions(+), 10 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index 3e961282d9c5..487b40dd69a8 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -50,6 +50,13 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpH private TaskCompletionSource? _appCompleted; + private StreamCompletionFlags _completionState; + private readonly object _completionLock = new object(); + + public bool EndStreamReceived => (_completionState & StreamCompletionFlags.EndStreamReceived) == StreamCompletionFlags.EndStreamReceived; + private bool IsAborted => (_completionState & StreamCompletionFlags.Aborted) == StreamCompletionFlags.Aborted; + internal bool RstStreamReceived => (_completionState & StreamCompletionFlags.RstStreamReceived) == StreamCompletionFlags.RstStreamReceived; + public Pipe RequestBodyPipe { get; } public Http3Stream(Http3StreamContext context) @@ -105,8 +112,12 @@ public Http3Stream(Http3StreamContext context) public void Abort(ConnectionAbortedException abortReason, Http3ErrorCode errorCode) { - // TODO - Should there be a check here to track abort state to avoid - // running twice for a request? + var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted); + + if (oldState == newState) + { + return; + } Log.Http3StreamAbort(TraceIdentifier, errorCode, abortReason); @@ -318,9 +329,37 @@ private static bool IsConnectionSpecificHeaderField(ReadOnlySpan name, Rea } protected override void OnRequestProcessingEnded() + { + CompleteStream(errored: false); + } + + private void CompleteStream(bool errored) { Debug.Assert(_appCompleted != null); _appCompleted.SetResult(); + + if (!EndStreamReceived) + { + if (!errored) + { + Log.RequestBodyNotEntirelyRead(ConnectionIdFeature, TraceIdentifier); + } + + var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted); + if (oldState != newState) + { + // https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-15 + // When the server does not need to receive the remainder of the request, it MAY abort reading + // the request stream, send a complete response, and cleanly close the sending part of the stream. + // The error code H3_NO_ERROR SHOULD be used when requesting that the client stop sending on the + // request stream. + + // TODO(JamesNK): Abort the read half of the stream with H3_NO_ERROR + // https://github.com/dotnet/aspnetcore/issues/33575 + + RequestBodyPipe.Writer.Complete(); + } + } } private bool TryClose() @@ -423,6 +462,8 @@ public async Task ProcessRequestAsync(IHttpApplication appli private ValueTask OnEndStreamReceived() { + ApplyCompletionFlag(StreamCompletionFlags.EndStreamReceived); + if (_requestHeaderParsingState == RequestHeaderParsingState.Ready) { // https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-14 @@ -552,7 +593,7 @@ private Task ProcessDataFrameAsync(in ReadOnlySequence payload) // https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1 if (_requestHeaderParsingState == RequestHeaderParsingState.Trailers) { - var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data))); + var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data)); throw new Http3ConnectionErrorException(message, Http3ErrorCode.UnexpectedFrame); } @@ -814,6 +855,17 @@ private Pipe CreateRequestBodyPipe(uint windowSize) minimumSegmentSize: _context.MemoryPool.GetMinimumSegmentSize() )); + private (StreamCompletionFlags OldState, StreamCompletionFlags NewState) ApplyCompletionFlag(StreamCompletionFlags completionState) + { + lock (_completionLock) + { + var oldCompletionState = _completionState; + _completionState |= completionState; + + return (oldCompletionState, _completionState); + } + } + /// /// Used to kick off the request processing loop by derived classes. /// @@ -839,6 +891,15 @@ private enum PseudoHeaderFields Unknown = 0x40000000 } + [Flags] + private enum StreamCompletionFlags + { + None = 0, + RstStreamReceived = 1, + EndStreamReceived = 2, + Aborted = 4, + } + private static class GracefulCloseInitiator { public const int None = 0; diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs index b9046d6513e7..808ad88ed280 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs @@ -1738,11 +1738,17 @@ public async Task FrameAfterTrailers_UnexpectedFrameError() { new KeyValuePair("TestName", "TestValue"), }; - var requestStream = await InitializeConnectionAndStreamsAsync(_noopApplication); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestStream = await InitializeConnectionAndStreamsAsync(async c => + { + // Send headers + await c.Response.Body.FlushAsync(); + + await tcs.Task; + }); await requestStream.SendHeadersAsync(headers, endStream: false); - // The app no-ops quickly. Wait for it here so it's not a race with the error response. await requestStream.ExpectHeadersAsync(); await requestStream.SendDataAsync(Encoding.UTF8.GetBytes("Hello world")); @@ -1752,6 +1758,8 @@ public async Task FrameAfterTrailers_UnexpectedFrameError() await requestStream.WaitForStreamErrorAsync( Http3ErrorCode.UnexpectedFrame, expectedErrorMessage: CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data))); + + tcs.SetResult(); } [Fact] @@ -2434,24 +2442,68 @@ await outboundcontrolStream.SendSettingsAsync(new List Assert.Equal(Internal.Http3.Http3SettingType.MaxFieldSectionSize, maxFieldSetting.Key); Assert.Equal(100, maxFieldSetting.Value); - + } + + [Fact] + public async Task PostRequest_ServerReadsPartialAndFinishes_SendsBodyWithEndStream() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var headers = new[] { - new KeyValuePair(HeaderNames.Method, "GET"), new KeyValuePair(HeaderNames.Path, "/"), new KeyValuePair(HeaderNames.Scheme, "http"), }; + var requestStream = await InitializeConnectionAndStreamsAsync(async context => + { + var buffer = new byte[1024]; + try + { + // Read 100 bytes + var readCount = 0; + while (readCount < 100) + { + readCount += await context.Request.Body.ReadAsync(buffer.AsMemory(readCount, 100 - readCount)); + } + await context.Response.Body.WriteAsync(buffer.AsMemory(0, 100)); var requestStream = await CreateRequestStream().DefaultTimeout(); + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); await requestStream.SendHeadersAsync(new[] + var sourceData = new byte[1024]; + for (var i = 0; i < sourceData.Length; i++) { + sourceData[i] = (byte)(i % byte.MaxValue); + } new KeyValuePair(HeaderNames.Path, "/"), - new KeyValuePair(HeaderNames.Scheme, "http"), + await requestStream.SendDataAsync(sourceData); new KeyValuePair(HeaderNames.Method, "GET"), - new KeyValuePair(HeaderNames.Authority, "localhost:80"), + var decodedHeaders = await requestStream.ExpectHeadersAsync(); + Assert.Equal(2, decodedHeaders.Count); + Assert.Equal("200", decodedHeaders[HeaderNames.Status]); }, endStream: true); + var data = await requestStream.ExpectDataAsync(); + + Assert.Equal(sourceData.AsMemory(0, 100).ToArray(), data.ToArray()); + + clientTcs.SetResult(0); + await appTcs.Task; + + await requestStream.ExpectReceiveEndOfStream(); - await requestStream.WaitForStreamErrorAsync(Http3ErrorCode.InternalError, "The encoded HTTP headers length exceeds the limit specified by the peer of 100 bytes."); + // TODO(JamesNK): Await the server aborting the sending half of the request stream. + // https://github.com/dotnet/aspnetcore/issues/33575 + await Task.Delay(1000); + + // Logged without an exception. } } } From 16194b33227b07f133cee060faeb0e9ba8d194bc Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Sat, 19 Jun 2021 18:41:49 +1200 Subject: [PATCH 2/2] Fix --- .../Http3/Http3StreamTests.cs | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs index 808ad88ed280..54d9e27a9ead 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs @@ -2442,19 +2442,26 @@ await outboundcontrolStream.SendSettingsAsync(new List Assert.Equal(Internal.Http3.Http3SettingType.MaxFieldSectionSize, maxFieldSetting.Key); Assert.Equal(100, maxFieldSetting.Value); + + var requestStream = await CreateRequestStream().DefaultTimeout(); + await requestStream.SendHeadersAsync(new[] + { + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Authority, "localhost:80"), + }, endStream: true); + + await requestStream.WaitForStreamErrorAsync(Http3ErrorCode.InternalError, "The encoded HTTP headers length exceeds the limit specified by the peer of 100 bytes."); } - + [Fact] public async Task PostRequest_ServerReadsPartialAndFinishes_SendsBodyWithEndStream() { var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var headers = new[] - { - new KeyValuePair(HeaderNames.Path, "/"), - new KeyValuePair(HeaderNames.Scheme, "http"), - }; + var requestStream = await InitializeConnectionAndStreamsAsync(async context => { var buffer = new byte[1024]; @@ -2468,7 +2475,6 @@ public async Task PostRequest_ServerReadsPartialAndFinishes_SendsBodyWithEndStre } await context.Response.Body.WriteAsync(buffer.AsMemory(0, 100)); - var requestStream = await CreateRequestStream().DefaultTimeout(); await clientTcs.Task.DefaultTimeout(); appTcs.SetResult(0); } @@ -2477,19 +2483,25 @@ public async Task PostRequest_ServerReadsPartialAndFinishes_SendsBodyWithEndStre appTcs.SetException(ex); } }); - await requestStream.SendHeadersAsync(new[] + var sourceData = new byte[1024]; for (var i = 0; i < sourceData.Length; i++) { sourceData[i] = (byte)(i % byte.MaxValue); } + + await requestStream.SendHeadersAsync(new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }); + await requestStream.SendDataAsync(sourceData); - new KeyValuePair(HeaderNames.Method, "GET"), var decodedHeaders = await requestStream.ExpectHeadersAsync(); Assert.Equal(2, decodedHeaders.Count); Assert.Equal("200", decodedHeaders[HeaderNames.Status]); - }, endStream: true); + var data = await requestStream.ExpectDataAsync(); Assert.Equal(sourceData.AsMemory(0, 100).ToArray(), data.ToArray()); @@ -2504,6 +2516,7 @@ await requestStream.SendHeadersAsync(new[] await Task.Delay(1000); // Logged without an exception. + Assert.Contains(LogMessages, m => m.Message.Contains("the application completed without reading the entire request body.")); } } }