Skip to content

Commit c5800e2

Browse files
authored
HTTP/3: Handle request completes with unread body content (#33578)
1 parent 214b734 commit c5800e2

File tree

2 files changed

+138
-12
lines changed

2 files changed

+138
-12
lines changed

src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpH
5050

5151
private TaskCompletionSource? _appCompleted;
5252

53+
private StreamCompletionFlags _completionState;
54+
private readonly object _completionLock = new object();
55+
56+
public bool EndStreamReceived => (_completionState & StreamCompletionFlags.EndStreamReceived) == StreamCompletionFlags.EndStreamReceived;
57+
private bool IsAborted => (_completionState & StreamCompletionFlags.Aborted) == StreamCompletionFlags.Aborted;
58+
internal bool RstStreamReceived => (_completionState & StreamCompletionFlags.RstStreamReceived) == StreamCompletionFlags.RstStreamReceived;
59+
5360
public Pipe RequestBodyPipe { get; }
5461

5562
public Http3Stream(Http3StreamContext context)
@@ -105,8 +112,12 @@ public Http3Stream(Http3StreamContext context)
105112

106113
public void Abort(ConnectionAbortedException abortReason, Http3ErrorCode errorCode)
107114
{
108-
// TODO - Should there be a check here to track abort state to avoid
109-
// running twice for a request?
115+
var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted);
116+
117+
if (oldState == newState)
118+
{
119+
return;
120+
}
110121

111122
Log.Http3StreamAbort(TraceIdentifier, errorCode, abortReason);
112123

@@ -318,9 +329,37 @@ private static bool IsConnectionSpecificHeaderField(ReadOnlySpan<byte> name, Rea
318329
}
319330

320331
protected override void OnRequestProcessingEnded()
332+
{
333+
CompleteStream(errored: false);
334+
}
335+
336+
private void CompleteStream(bool errored)
321337
{
322338
Debug.Assert(_appCompleted != null);
323339
_appCompleted.SetResult();
340+
341+
if (!EndStreamReceived)
342+
{
343+
if (!errored)
344+
{
345+
Log.RequestBodyNotEntirelyRead(ConnectionIdFeature, TraceIdentifier);
346+
}
347+
348+
var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted);
349+
if (oldState != newState)
350+
{
351+
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-15
352+
// When the server does not need to receive the remainder of the request, it MAY abort reading
353+
// the request stream, send a complete response, and cleanly close the sending part of the stream.
354+
// The error code H3_NO_ERROR SHOULD be used when requesting that the client stop sending on the
355+
// request stream.
356+
357+
// TODO(JamesNK): Abort the read half of the stream with H3_NO_ERROR
358+
// https://github.com/dotnet/aspnetcore/issues/33575
359+
360+
RequestBodyPipe.Writer.Complete();
361+
}
362+
}
324363
}
325364

326365
private bool TryClose()
@@ -423,6 +462,8 @@ public async Task ProcessRequestAsync<TContext>(IHttpApplication<TContext> appli
423462

424463
private ValueTask OnEndStreamReceived()
425464
{
465+
ApplyCompletionFlag(StreamCompletionFlags.EndStreamReceived);
466+
426467
if (_requestHeaderParsingState == RequestHeaderParsingState.Ready)
427468
{
428469
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-14
@@ -552,7 +593,7 @@ private Task ProcessDataFrameAsync(in ReadOnlySequence<byte> payload)
552593
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1
553594
if (_requestHeaderParsingState == RequestHeaderParsingState.Trailers)
554595
{
555-
var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data)));
596+
var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data));
556597
throw new Http3ConnectionErrorException(message, Http3ErrorCode.UnexpectedFrame);
557598
}
558599

@@ -814,6 +855,17 @@ private Pipe CreateRequestBodyPipe(uint windowSize)
814855
minimumSegmentSize: _context.MemoryPool.GetMinimumSegmentSize()
815856
));
816857

858+
private (StreamCompletionFlags OldState, StreamCompletionFlags NewState) ApplyCompletionFlag(StreamCompletionFlags completionState)
859+
{
860+
lock (_completionLock)
861+
{
862+
var oldCompletionState = _completionState;
863+
_completionState |= completionState;
864+
865+
return (oldCompletionState, _completionState);
866+
}
867+
}
868+
817869
/// <summary>
818870
/// Used to kick off the request processing loop by derived classes.
819871
/// </summary>
@@ -839,6 +891,15 @@ private enum PseudoHeaderFields
839891
Unknown = 0x40000000
840892
}
841893

894+
[Flags]
895+
private enum StreamCompletionFlags
896+
{
897+
None = 0,
898+
RstStreamReceived = 1,
899+
EndStreamReceived = 2,
900+
Aborted = 4,
901+
}
902+
842903
private static class GracefulCloseInitiator
843904
{
844905
public const int None = 0;

src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,11 +1738,17 @@ public async Task FrameAfterTrailers_UnexpectedFrameError()
17381738
{
17391739
new KeyValuePair<string, string>("TestName", "TestValue"),
17401740
};
1741-
var requestStream = await InitializeConnectionAndStreamsAsync(_noopApplication);
1741+
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
1742+
var requestStream = await InitializeConnectionAndStreamsAsync(async c =>
1743+
{
1744+
// Send headers
1745+
await c.Response.Body.FlushAsync();
1746+
1747+
await tcs.Task;
1748+
});
17421749

17431750
await requestStream.SendHeadersAsync(headers, endStream: false);
17441751

1745-
// The app no-ops quickly. Wait for it here so it's not a race with the error response.
17461752
await requestStream.ExpectHeadersAsync();
17471753

17481754
await requestStream.SendDataAsync(Encoding.UTF8.GetBytes("Hello world"));
@@ -1752,6 +1758,8 @@ public async Task FrameAfterTrailers_UnexpectedFrameError()
17521758
await requestStream.WaitForStreamErrorAsync(
17531759
Http3ErrorCode.UnexpectedFrame,
17541760
expectedErrorMessage: CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data)));
1761+
1762+
tcs.SetResult();
17551763
}
17561764

17571765
[Fact]
@@ -2435,13 +2443,6 @@ await outboundcontrolStream.SendSettingsAsync(new List<Http3PeerSetting>
24352443
Assert.Equal(Internal.Http3.Http3SettingType.MaxFieldSectionSize, maxFieldSetting.Key);
24362444
Assert.Equal(100, maxFieldSetting.Value);
24372445

2438-
var headers = new[]
2439-
{
2440-
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
2441-
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
2442-
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
2443-
};
2444-
24452446
var requestStream = await CreateRequestStream().DefaultTimeout();
24462447
await requestStream.SendHeadersAsync(new[]
24472448
{
@@ -2453,5 +2454,69 @@ await requestStream.SendHeadersAsync(new[]
24532454

24542455
await requestStream.WaitForStreamErrorAsync(Http3ErrorCode.InternalError, "The encoded HTTP headers length exceeds the limit specified by the peer of 100 bytes.");
24552456
}
2457+
2458+
[Fact]
2459+
public async Task PostRequest_ServerReadsPartialAndFinishes_SendsBodyWithEndStream()
2460+
{
2461+
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
2462+
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
2463+
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
2464+
2465+
var requestStream = await InitializeConnectionAndStreamsAsync(async context =>
2466+
{
2467+
var buffer = new byte[1024];
2468+
try
2469+
{
2470+
// Read 100 bytes
2471+
var readCount = 0;
2472+
while (readCount < 100)
2473+
{
2474+
readCount += await context.Request.Body.ReadAsync(buffer.AsMemory(readCount, 100 - readCount));
2475+
}
2476+
2477+
await context.Response.Body.WriteAsync(buffer.AsMemory(0, 100));
2478+
await clientTcs.Task.DefaultTimeout();
2479+
appTcs.SetResult(0);
2480+
}
2481+
catch (Exception ex)
2482+
{
2483+
appTcs.SetException(ex);
2484+
}
2485+
});
2486+
2487+
var sourceData = new byte[1024];
2488+
for (var i = 0; i < sourceData.Length; i++)
2489+
{
2490+
sourceData[i] = (byte)(i % byte.MaxValue);
2491+
}
2492+
2493+
await requestStream.SendHeadersAsync(new[]
2494+
{
2495+
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
2496+
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
2497+
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
2498+
});
2499+
2500+
await requestStream.SendDataAsync(sourceData);
2501+
var decodedHeaders = await requestStream.ExpectHeadersAsync();
2502+
Assert.Equal(2, decodedHeaders.Count);
2503+
Assert.Equal("200", decodedHeaders[HeaderNames.Status]);
2504+
2505+
var data = await requestStream.ExpectDataAsync();
2506+
2507+
Assert.Equal(sourceData.AsMemory(0, 100).ToArray(), data.ToArray());
2508+
2509+
clientTcs.SetResult(0);
2510+
await appTcs.Task;
2511+
2512+
await requestStream.ExpectReceiveEndOfStream();
2513+
2514+
// TODO(JamesNK): Await the server aborting the sending half of the request stream.
2515+
// https://github.com/dotnet/aspnetcore/issues/33575
2516+
await Task.Delay(1000);
2517+
2518+
// Logged without an exception.
2519+
Assert.Contains(LogMessages, m => m.Message.Contains("the application completed without reading the entire request body."));
2520+
}
24562521
}
24572522
}

0 commit comments

Comments
 (0)