diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs index 4d5513293e66..819777d65494 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs @@ -3,11 +3,13 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { @@ -15,6 +17,7 @@ internal sealed class HttpRequestStream : Stream { private readonly HttpRequestPipeReader _pipeReader; private readonly IHttpBodyControlFeature _bodyControl; + private AsyncEnumerableReader _asyncReader; public HttpRequestStream(IHttpBodyControlFeature bodyControl, HttpRequestPipeReader pipeReader) { @@ -44,12 +47,26 @@ public override int WriteTimeout public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) { - return ReadAsyncWrapper(destination, cancellationToken); + try + { + return ReadAsyncInternal(destination, cancellationToken); + } + catch (ConnectionAbortedException ex) + { + throw new TaskCanceledException("The request was aborted", ex); + } } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return ReadAsyncWrapper(new Memory(buffer, offset, count), cancellationToken).AsTask(); + try + { + return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + catch (ConnectionAbortedException ex) + { + throw new TaskCanceledException("The request was aborted", ex); + } } public override int Read(byte[] buffer, int offset, int count) @@ -127,23 +144,78 @@ private Task ReadAsync(byte[] buffer, int offset, int count, CancellationTo return tcs.Task; } - private ValueTask ReadAsyncWrapper(Memory destination, CancellationToken cancellationToken) + private ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) { + if (_asyncReader?.InProgress ?? false) + { + // Throw if there are overlapping reads; throwing unwrapped as it suggests last read was not awaited + // so we surface it directly rather than wrapped in a Task (as this one will likely also not be awaited). + throw new InvalidOperationException("Concurrent reads are not supported; await the " + nameof(ValueTask) + " before starting next read."); + } + try { - return ReadAsyncInternal(destination, cancellationToken); + while (true) + { + if (!_pipeReader.TryRead(out var result)) + { + break; + } + + if (result.IsCanceled) + { + throw new OperationCanceledException("The read was canceled"); + } + + var readableBuffer = result.Buffer; + var readableBufferLength = readableBuffer.Length; + + var consumed = readableBuffer.End; + var actual = 0; + try + { + if (readableBufferLength != 0) + { + actual = (int)Math.Min(readableBufferLength, buffer.Length); + + var slice = actual == readableBufferLength ? readableBuffer : readableBuffer.Slice(0, actual); + consumed = slice.End; + slice.CopyTo(buffer.Span); + + return new ValueTask(actual); + } + + if (result.IsCompleted) + { + return new ValueTask(0); + } + } + finally + { + _pipeReader.AdvanceTo(consumed); + } + } } - catch (ConnectionAbortedException ex) + catch (Exception ex) { - throw new TaskCanceledException("The request was aborted", ex); + return new ValueTask(Task.FromException(ex)); + } + + var asyncReader = _asyncReader; + if (asyncReader is null) + { + _asyncReader = asyncReader = new AsyncEnumerableReader(); + asyncReader.Initialize(ReadAsyncAwaited(asyncReader)); } + + return asyncReader.ReadAsync(buffer, cancellationToken); } - private async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) + private async IAsyncEnumerable ReadAsyncAwaited(AsyncEnumerableReader reader) { while (true) { - var result = await _pipeReader.ReadAsync(cancellationToken); + var result = await _pipeReader.ReadAsync(reader.CancellationToken); if (result.IsCanceled) { @@ -154,30 +226,40 @@ private async ValueTask ReadAsyncInternal(Memory buffer, Cancellation var readableBufferLength = readableBuffer.Length; var consumed = readableBuffer.End; + var advanced = false; try { if (readableBufferLength != 0) { - var actual = (int)Math.Min(readableBufferLength, buffer.Length); + var actual = (int)Math.Min(readableBufferLength, reader.Buffer.Length); var slice = actual == readableBufferLength ? readableBuffer : readableBuffer.Slice(0, actual); consumed = slice.End; - slice.CopyTo(buffer.Span); + slice.CopyTo(reader.Buffer.Span); - return actual; + // Finally blocks in enumerators aren't excuted prior to the yield return, + // so we advance here + advanced = true; + _pipeReader.AdvanceTo(consumed); + yield return actual; } - - if (result.IsCompleted) + else if (result.IsCompleted) { - return 0; + // Finally blocks in enumerators aren't excuted prior to the yield return, + // so we advance here + advanced = true; + _pipeReader.AdvanceTo(consumed); + yield return 0; } } finally { - _pipeReader.AdvanceTo(consumed); + if (!advanced) + { + _pipeReader.AdvanceTo(consumed); + } } } - } /// diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs index 7edd16713857..7dad03647508 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs @@ -69,11 +69,23 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami public override bool TryRead(out ReadResult readResult) { - var result = _context.RequestBodyPipe.Reader.TryRead(out readResult); - _readResult = readResult; - CountBytesRead(readResult.Buffer.Length); + TryStart(); + + var hasResult = _context.RequestBodyPipe.Reader.TryRead(out readResult); + + if (hasResult) + { + _readResult = readResult; + + CountBytesRead(readResult.Buffer.Length); + + if (readResult.IsCompleted) + { + TryStop(); + } + } - return result; + return hasResult; } public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/AsyncEnumerableReader.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/AsyncEnumerableReader.cs new file mode 100644 index 000000000000..73fcd72b13f0 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/AsyncEnumerableReader.cs @@ -0,0 +1,135 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal class AsyncEnumerableReader : IValueTaskSource + { + private readonly Action _onCompletedAction; + + private ManualResetValueTaskSourceCore _valueTaskSource; + private IAsyncEnumerable _readerSource; + private IAsyncEnumerator _reader; + + public Memory Buffer { get; private set; } + public CancellationToken CancellationToken { get; private set; } + + private ValueTaskAwaiter _readAwaiter; + + private volatile bool _inProgress; + public bool InProgress => _inProgress; + + public AsyncEnumerableReader() + { + _onCompletedAction = OnCompleted; + } + + internal void Initialize(IAsyncEnumerable readerSource) + { + _readerSource = readerSource; + _reader = readerSource.GetAsyncEnumerator(); + } + + public ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) + { + if (_readerSource is null) + { + ThrowNotInitialized(); + } + + if (_inProgress) + { + ThrowConcurrentReadsNotSupported(); + } + _inProgress = true; + + Buffer = buffer; + CancellationToken = cancellationToken; + + var task = _reader.MoveNextAsync(); + _readAwaiter = task.GetAwaiter(); + + return new ValueTask(this, _valueTaskSource.Version); + } + + int IValueTaskSource.GetResult(short token) + { + var isValid = token == _valueTaskSource.Version; + try + { + return _valueTaskSource.GetResult(token); + } + finally + { + if (isValid) + { + Buffer = default; + CancellationToken = default; + _inProgress = false; + _valueTaskSource.Reset(); + } + } + } + + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) + => _valueTaskSource.GetStatus(token); + + void IValueTaskSource.OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + { + if (!InProgress) + { + ThrowNoReadInProgress(); + } + + _valueTaskSource.OnCompleted(continuation, state, token, flags); + + _readAwaiter.UnsafeOnCompleted(_onCompletedAction); + } + + private void OnCompleted() + { + try + { + if (_readAwaiter.GetResult()) + { + _valueTaskSource.SetResult(_reader.Current); + } + else + { + _valueTaskSource.SetResult(-1); + } + } + catch (Exception ex) + { + // If the GetResult throws for this ReadAsync (e.g. cancellation), + // that will cause all next ReadAsyncs to also throw, so we create + // a fresh unerrored AsyncEnumerable to restore the next ReadAsyncs + // to the normal flow + _reader = _readerSource.GetAsyncEnumerator(); + _valueTaskSource.SetException(ex); + } + } + + static void ThrowConcurrentReadsNotSupported() + { + throw new InvalidOperationException("Concurrent reads are not supported"); + } + + static void ThrowNoReadInProgress() + { + throw new InvalidOperationException("No read in progress, await will not complete"); + } + + static void ThrowNotInitialized() + { + throw new InvalidOperationException(nameof(AsyncEnumerableReader) + " has not been initialized"); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Middleware/Internal/DuplexPipeStream.cs b/src/Servers/Kestrel/Core/src/Middleware/Internal/DuplexPipeStream.cs index b81a5e886973..23e86dd3b5a3 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/Internal/DuplexPipeStream.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/Internal/DuplexPipeStream.cs @@ -7,6 +7,8 @@ using System.Threading; using System.Threading.Tasks; using System.Buffers; +using System.Collections.Generic; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { @@ -17,6 +19,8 @@ internal class DuplexPipeStream : Stream private readonly bool _throwOnCancelled; private volatile bool _cancelCalled; + private AsyncEnumerableReader _asyncReader; + public DuplexPipeStream(PipeReader input, PipeWriter output, bool throwOnCancelled = false) { _input = input; @@ -114,12 +118,77 @@ public override Task FlushAsync(CancellationToken cancellationToken) return WriteAsync(null, 0, 0, cancellationToken); } - private async ValueTask ReadAsyncInternal(Memory destination, CancellationToken cancellationToken) + private ValueTask ReadAsyncInternal(Memory destination, CancellationToken cancellationToken) + { + if (_asyncReader?.InProgress ?? false) + { + // Throw if there are overlapping reads; throwing unwrapped as it suggests last read was not awaited + // so we surface it directly rather than wrapped in a Task (as this one will likely also not be awaited). + throw new InvalidOperationException("Concurrent reads are not supported; await the " + nameof(ValueTask) + " before starting next read."); + } + + try + { + while (true) + { + if (!_input.TryRead(out var result)) + { + break; + } + + var readableBuffer = result.Buffer; + try + { + if (_throwOnCancelled && result.IsCanceled && _cancelCalled) + { + // Reset the bool + _cancelCalled = false; + throw new OperationCanceledException(); + } + + if (!readableBuffer.IsEmpty) + { + // buffer.Count is int + var count = (int)Math.Min(readableBuffer.Length, destination.Length); + readableBuffer = readableBuffer.Slice(0, count); + readableBuffer.CopyTo(destination.Span); + return new ValueTask(count); + } + + if (result.IsCompleted) + { + return new ValueTask(0); + } + } + finally + { + _input.AdvanceTo(readableBuffer.End, readableBuffer.End); + } + } + } + catch (Exception ex) + { + return new ValueTask(Task.FromException(ex)); + } + + var asyncReader = _asyncReader; + if (asyncReader is null) + { + _asyncReader = asyncReader = new AsyncEnumerableReader(); + asyncReader.Initialize(ReadAsyncAwaited(asyncReader)); + } + + return asyncReader.ReadAsync(destination, cancellationToken); + } + + private async IAsyncEnumerable ReadAsyncAwaited(AsyncEnumerableReader reader) { while (true) { - var result = await _input.ReadAsync(cancellationToken); + var result = await _input.ReadAsync(reader.CancellationToken); var readableBuffer = result.Buffer; + + var advanced = false; try { if (_throwOnCancelled && result.IsCanceled && _cancelCalled) @@ -128,24 +197,34 @@ private async ValueTask ReadAsyncInternal(Memory destination, Cancell _cancelCalled = false; throw new OperationCanceledException(); } - - if (!readableBuffer.IsEmpty) + else if (!readableBuffer.IsEmpty) { // buffer.Count is int - var count = (int)Math.Min(readableBuffer.Length, destination.Length); + var count = (int)Math.Min(readableBuffer.Length, reader.Buffer.Length); readableBuffer = readableBuffer.Slice(0, count); - readableBuffer.CopyTo(destination.Span); - return count; - } + readableBuffer.CopyTo(reader.Buffer.Span); - if (result.IsCompleted) + // Finally blocks in enumerators aren't excuted prior to the yield return, + // so we advance here + advanced = true; + _input.AdvanceTo(readableBuffer.End, readableBuffer.End); + yield return count; + } + else if (result.IsCompleted) { - return 0; + // Finally blocks in enumerators aren't excuted prior to the yield return, + // so we advance here + advanced = true; + _input.AdvanceTo(readableBuffer.End, readableBuffer.End); + yield return 0; } } finally { - _input.AdvanceTo(readableBuffer.End, readableBuffer.End); + if (!advanced) + { + _input.AdvanceTo(readableBuffer.End, readableBuffer.End); + } } } }