Skip to content

Commit 32d98c6

Browse files
committed
HTTP/3: Handle request completes with unread body content
1 parent 3e3f51a commit 32d98c6

File tree

2 files changed

+138
-5
lines changed

2 files changed

+138
-5
lines changed

src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpH
5050

5151
private TaskCompletionSource? _appCompleted;
5252

53+
private StreamCompletionFlags _completionState;
54+
private readonly object _completionLock = new object();
55+
56+
public bool EndStreamReceived => (_completionState & StreamCompletionFlags.EndStreamReceived) == StreamCompletionFlags.EndStreamReceived;
57+
private bool IsAborted => (_completionState & StreamCompletionFlags.Aborted) == StreamCompletionFlags.Aborted;
58+
internal bool RstStreamReceived => (_completionState & StreamCompletionFlags.RstStreamReceived) == StreamCompletionFlags.RstStreamReceived;
59+
5360
public Pipe RequestBodyPipe { get; }
5461

5562
public Http3Stream(Http3StreamContext context)
@@ -103,8 +110,12 @@ public Http3Stream(Http3StreamContext context)
103110

104111
public void Abort(ConnectionAbortedException abortReason, Http3ErrorCode errorCode)
105112
{
106-
// TODO - Should there be a check here to track abort state to avoid
107-
// running twice for a request?
113+
var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted);
114+
115+
if (oldState == newState)
116+
{
117+
return;
118+
}
108119

109120
Log.Http3StreamAbort(TraceIdentifier, errorCode, abortReason);
110121

@@ -316,9 +327,36 @@ private static bool IsConnectionSpecificHeaderField(ReadOnlySpan<byte> name, Rea
316327
}
317328

318329
protected override void OnRequestProcessingEnded()
330+
{
331+
CompleteStream(errored: false);
332+
}
333+
334+
private void CompleteStream(bool errored)
319335
{
320336
Debug.Assert(_appCompleted != null);
321337
_appCompleted.SetResult();
338+
339+
if (!EndStreamReceived)
340+
{
341+
if (!errored)
342+
{
343+
Log.RequestBodyNotEntirelyRead(ConnectionIdFeature, TraceIdentifier);
344+
}
345+
346+
var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted);
347+
if (oldState != newState)
348+
{
349+
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-15
350+
// When the server does not need to receive the remainder of the request, it MAY abort reading
351+
// the request stream, send a complete response, and cleanly close the sending part of the stream.
352+
// The error code H3_NO_ERROR SHOULD be used when requesting that the client stop sending on the
353+
// request stream.
354+
355+
// TODO(JamesNK): Abort the read half of the stream with H3_NO_ERROR
356+
357+
RequestBodyPipe.Writer.Complete();
358+
}
359+
}
322360
}
323361

324362
private bool TryClose()
@@ -421,6 +459,8 @@ public async Task ProcessRequestAsync<TContext>(IHttpApplication<TContext> appli
421459

422460
private ValueTask OnEndStreamReceived()
423461
{
462+
ApplyCompletionFlag(StreamCompletionFlags.EndStreamReceived);
463+
424464
if (_requestHeaderParsingState == RequestHeaderParsingState.Ready)
425465
{
426466
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-14
@@ -542,7 +582,7 @@ private Task ProcessDataFrameAsync(in ReadOnlySequence<byte> payload)
542582
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1
543583
if (_requestHeaderParsingState == RequestHeaderParsingState.Trailers)
544584
{
545-
var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data)));
585+
var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data));
546586
throw new Http3ConnectionErrorException(message, Http3ErrorCode.UnexpectedFrame);
547587
}
548588

@@ -804,6 +844,17 @@ private Pipe CreateRequestBodyPipe(uint windowSize)
804844
minimumSegmentSize: _context.MemoryPool.GetMinimumSegmentSize()
805845
));
806846

847+
private (StreamCompletionFlags OldState, StreamCompletionFlags NewState) ApplyCompletionFlag(StreamCompletionFlags completionState)
848+
{
849+
lock (_completionLock)
850+
{
851+
var oldCompletionState = _completionState;
852+
_completionState |= completionState;
853+
854+
return (oldCompletionState, _completionState);
855+
}
856+
}
857+
807858
/// <summary>
808859
/// Used to kick off the request processing loop by derived classes.
809860
/// </summary>
@@ -829,6 +880,15 @@ private enum PseudoHeaderFields
829880
Unknown = 0x40000000
830881
}
831882

883+
[Flags]
884+
private enum StreamCompletionFlags
885+
{
886+
None = 0,
887+
RstStreamReceived = 1,
888+
EndStreamReceived = 2,
889+
Aborted = 4,
890+
}
891+
832892
private static class GracefulCloseInitiator
833893
{
834894
public const int None = 0;

src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,11 +1738,17 @@ public async Task FrameAfterTrailers_UnexpectedFrameError()
17381738
{
17391739
new KeyValuePair<string, string>("TestName", "TestValue"),
17401740
};
1741-
var requestStream = await InitializeConnectionAndStreamsAsync(_noopApplication);
1741+
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
1742+
var requestStream = await InitializeConnectionAndStreamsAsync(async c =>
1743+
{
1744+
// Send headers
1745+
await c.Response.Body.FlushAsync();
1746+
1747+
await tcs.Task;
1748+
});
17421749

17431750
await requestStream.SendHeadersAsync(headers, endStream: false);
17441751

1745-
// The app no-ops quickly. Wait for it here so it's not a race with the error response.
17461752
await requestStream.ExpectHeadersAsync();
17471753

17481754
await requestStream.SendDataAsync(Encoding.UTF8.GetBytes("Hello world"));
@@ -1752,6 +1758,8 @@ public async Task FrameAfterTrailers_UnexpectedFrameError()
17521758
await requestStream.WaitForStreamErrorAsync(
17531759
Http3ErrorCode.UnexpectedFrame,
17541760
expectedErrorMessage: CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data)));
1761+
1762+
tcs.SetResult();
17551763
}
17561764

17571765
[Fact]
@@ -2414,5 +2422,70 @@ await requestStream.SendHeadersAsync(new[]
24142422
await requestStream.ExpectHeadersAsync().DefaultTimeout();
24152423
await requestStream.ExpectReceiveEndOfStream().DefaultTimeout();
24162424
}
2425+
2426+
[Fact]
2427+
public async Task PostRequest_ServerReadsPartialAndFinishes_SendsBodyWithEndStream()
2428+
{
2429+
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
2430+
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
2431+
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
2432+
var headers = new[]
2433+
{
2434+
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
2435+
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
2436+
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
2437+
};
2438+
var requestStream = await InitializeConnectionAndStreamsAsync(async context =>
2439+
{
2440+
var buffer = new byte[1024];
2441+
try
2442+
{
2443+
// Read 100 bytes
2444+
var readCount = 0;
2445+
while (readCount < 100)
2446+
{
2447+
readCount += await context.Request.Body.ReadAsync(buffer.AsMemory(readCount, 100 - readCount));
2448+
}
2449+
2450+
await context.Response.Body.WriteAsync(buffer.AsMemory(0, 100));
2451+
2452+
await clientTcs.Task.DefaultTimeout();
2453+
appTcs.SetResult(0);
2454+
}
2455+
catch (Exception ex)
2456+
{
2457+
appTcs.SetException(ex);
2458+
}
2459+
});
2460+
2461+
var sourceData = new byte[1024];
2462+
for (var i = 0; i < sourceData.Length; i++)
2463+
{
2464+
sourceData[i] = (byte)(i % byte.MaxValue);
2465+
}
2466+
2467+
await requestStream.SendHeadersAsync(headers, endStream: false);
2468+
await requestStream.SendDataAsync(sourceData);
2469+
2470+
var decodedHeaders = await requestStream.ExpectHeadersAsync();
2471+
Assert.Equal(2, decodedHeaders.Count);
2472+
Assert.Contains("date", decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
2473+
Assert.Equal("200", decodedHeaders[HeaderNames.Status]);
2474+
2475+
var data = await requestStream.ExpectDataAsync();
2476+
2477+
Assert.Equal(sourceData.AsMemory(0, 100).ToArray(), data.ToArray());
2478+
2479+
clientTcs.SetResult(0);
2480+
await appTcs.Task;
2481+
2482+
await requestStream.ExpectReceiveEndOfStream();
2483+
2484+
// TODO(JamesNK): Await the server aborting the sending half of the request stream.
2485+
await Task.Delay(1000);
2486+
2487+
// Logged without an exception.
2488+
Assert.Contains(LogMessages, m => m.Message.Contains("the application completed without reading the entire request body."));
2489+
}
24172490
}
24182491
}

0 commit comments

Comments
 (0)