Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpH

private TaskCompletionSource? _appCompleted;

private StreamCompletionFlags _completionState;
private readonly object _completionLock = new object();

public bool EndStreamReceived => (_completionState & StreamCompletionFlags.EndStreamReceived) == StreamCompletionFlags.EndStreamReceived;
private bool IsAborted => (_completionState & StreamCompletionFlags.Aborted) == StreamCompletionFlags.Aborted;
internal bool RstStreamReceived => (_completionState & StreamCompletionFlags.RstStreamReceived) == StreamCompletionFlags.RstStreamReceived;

public Pipe RequestBodyPipe { get; }

public Http3Stream(Http3StreamContext context)
Expand Down Expand Up @@ -105,8 +112,12 @@ public Http3Stream(Http3StreamContext context)

public void Abort(ConnectionAbortedException abortReason, Http3ErrorCode errorCode)
{
// TODO - Should there be a check here to track abort state to avoid
// running twice for a request?
var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted);

if (oldState == newState)
{
return;
}

Log.Http3StreamAbort(TraceIdentifier, errorCode, abortReason);

Expand Down Expand Up @@ -318,9 +329,37 @@ private static bool IsConnectionSpecificHeaderField(ReadOnlySpan<byte> name, Rea
}

protected override void OnRequestProcessingEnded()
{
CompleteStream(errored: false);
}

private void CompleteStream(bool errored)
{
Debug.Assert(_appCompleted != null);
_appCompleted.SetResult();

if (!EndStreamReceived)
{
if (!errored)
{
Log.RequestBodyNotEntirelyRead(ConnectionIdFeature, TraceIdentifier);
}

var (oldState, newState) = ApplyCompletionFlag(StreamCompletionFlags.Aborted);
if (oldState != newState)
{
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-15
// When the server does not need to receive the remainder of the request, it MAY abort reading
// the request stream, send a complete response, and cleanly close the sending part of the stream.
// The error code H3_NO_ERROR SHOULD be used when requesting that the client stop sending on the
// request stream.

// TODO(JamesNK): Abort the read half of the stream with H3_NO_ERROR
// https://github.com/dotnet/aspnetcore/issues/33575

RequestBodyPipe.Writer.Complete();
}
}
}

private bool TryClose()
Expand Down Expand Up @@ -423,6 +462,8 @@ public async Task ProcessRequestAsync<TContext>(IHttpApplication<TContext> appli

private ValueTask OnEndStreamReceived()
{
ApplyCompletionFlag(StreamCompletionFlags.EndStreamReceived);

if (_requestHeaderParsingState == RequestHeaderParsingState.Ready)
{
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1-14
Expand Down Expand Up @@ -552,7 +593,7 @@ private Task ProcessDataFrameAsync(in ReadOnlySequence<byte> payload)
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1
if (_requestHeaderParsingState == RequestHeaderParsingState.Trailers)
{
var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data)));
var message = CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data));
throw new Http3ConnectionErrorException(message, Http3ErrorCode.UnexpectedFrame);
}

Expand Down Expand Up @@ -814,6 +855,17 @@ private Pipe CreateRequestBodyPipe(uint windowSize)
minimumSegmentSize: _context.MemoryPool.GetMinimumSegmentSize()
));

private (StreamCompletionFlags OldState, StreamCompletionFlags NewState) ApplyCompletionFlag(StreamCompletionFlags completionState)
{
lock (_completionLock)
{
var oldCompletionState = _completionState;
_completionState |= completionState;

return (oldCompletionState, _completionState);
}
}

/// <summary>
/// Used to kick off the request processing loop by derived classes.
/// </summary>
Expand All @@ -839,6 +891,15 @@ private enum PseudoHeaderFields
Unknown = 0x40000000
}

[Flags]
private enum StreamCompletionFlags
{
None = 0,
RstStreamReceived = 1,
EndStreamReceived = 2,
Aborted = 4,
}

private static class GracefulCloseInitiator
{
public const int None = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1738,11 +1738,17 @@ public async Task FrameAfterTrailers_UnexpectedFrameError()
{
new KeyValuePair<string, string>("TestName", "TestValue"),
};
var requestStream = await InitializeConnectionAndStreamsAsync(_noopApplication);
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var requestStream = await InitializeConnectionAndStreamsAsync(async c =>
{
// Send headers
await c.Response.Body.FlushAsync();

await tcs.Task;
});

await requestStream.SendHeadersAsync(headers, endStream: false);

// The app no-ops quickly. Wait for it here so it's not a race with the error response.
await requestStream.ExpectHeadersAsync();

await requestStream.SendDataAsync(Encoding.UTF8.GetBytes("Hello world"));
Expand All @@ -1752,6 +1758,8 @@ public async Task FrameAfterTrailers_UnexpectedFrameError()
await requestStream.WaitForStreamErrorAsync(
Http3ErrorCode.UnexpectedFrame,
expectedErrorMessage: CoreStrings.FormatHttp3StreamErrorFrameReceivedAfterTrailers(Http3Formatting.ToFormattedType(Http3FrameType.Data)));

tcs.SetResult();
}

[Fact]
Expand Down Expand Up @@ -2435,13 +2443,6 @@ await outboundcontrolStream.SendSettingsAsync(new List<Http3PeerSetting>
Assert.Equal(Internal.Http3.Http3SettingType.MaxFieldSectionSize, maxFieldSetting.Key);
Assert.Equal(100, maxFieldSetting.Value);

var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};

var requestStream = await CreateRequestStream().DefaultTimeout();
await requestStream.SendHeadersAsync(new[]
{
Expand All @@ -2453,5 +2454,69 @@ await requestStream.SendHeadersAsync(new[]

await requestStream.WaitForStreamErrorAsync(Http3ErrorCode.InternalError, "The encoded HTTP headers length exceeds the limit specified by the peer of 100 bytes.");
}

[Fact]
public async Task PostRequest_ServerReadsPartialAndFinishes_SendsBodyWithEndStream()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);

var requestStream = await InitializeConnectionAndStreamsAsync(async context =>
{
var buffer = new byte[1024];
try
{
// Read 100 bytes
var readCount = 0;
while (readCount < 100)
{
readCount += await context.Request.Body.ReadAsync(buffer.AsMemory(readCount, 100 - readCount));
}

await context.Response.Body.WriteAsync(buffer.AsMemory(0, 100));
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});

var sourceData = new byte[1024];
for (var i = 0; i < sourceData.Length; i++)
{
sourceData[i] = (byte)(i % byte.MaxValue);
}

await requestStream.SendHeadersAsync(new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
});

await requestStream.SendDataAsync(sourceData);
var decodedHeaders = await requestStream.ExpectHeadersAsync();
Assert.Equal(2, decodedHeaders.Count);
Assert.Equal("200", decodedHeaders[HeaderNames.Status]);

var data = await requestStream.ExpectDataAsync();

Assert.Equal(sourceData.AsMemory(0, 100).ToArray(), data.ToArray());

clientTcs.SetResult(0);
await appTcs.Task;

await requestStream.ExpectReceiveEndOfStream();

// TODO(JamesNK): Await the server aborting the sending half of the request stream.
// https://github.com/dotnet/aspnetcore/issues/33575
await Task.Delay(1000);

// Logged without an exception.
Assert.Contains(LogMessages, m => m.Message.Contains("the application completed without reading the entire request body."));
}
}
}